From 5ae2cc84876149f05600767f93ae9f6aa49c3ed0 Mon Sep 17 00:00:00 2001 From: Louise Davies Date: Thu, 28 May 2020 09:43:29 +0100 Subject: [PATCH] #125 - add abstract "backend" class and move database specific code in endpoints to there. Also, an change in that errors are no longer caught, and are translated by flask into error codes. --- common/backends.py | 317 ++++++++++++++++++ common/config.py | 6 + common/database_helpers.py | 32 +- common/exceptions.py | 33 +- common/helpers.py | 85 +++-- common/models/db_models.py | 144 +++++--- config.json.example | 1 + src/main.py | 32 +- src/resources/entities/entity_endpoint.py | 39 +-- .../non_entities/sessions_endpoints.py | 25 +- .../table_endpoints/table_endpoints.py | 31 +- test/test_base.py | 2 +- test/test_helpers.py | 95 ++++-- 13 files changed, 633 insertions(+), 209 deletions(-) create mode 100644 common/backends.py diff --git a/common/backends.py b/common/backends.py new file mode 100644 index 00000000..03b2b3a6 --- /dev/null +++ b/common/backends.py @@ -0,0 +1,317 @@ +from abc import ABC, abstractmethod +from common.database_helpers import get_investigations_for_user, get_investigations_for_user_count, \ + get_facility_cycles_for_instrument, get_facility_cycles_for_instrument_count, \ + get_investigations_for_instrument_in_facility_cycle, get_investigations_for_instrument_in_facility_cycle_count, \ + get_rows_by_filter, create_rows_from_json, patch_entities, get_row_by_id, insert_row_into_table, \ + delete_row_by_id, update_row_from_id, get_filtered_row_count, get_first_filtered_row +from common.helpers import requires_session_id, queries_records +from common.models.db_models import SESSION +from common.config import config +import uuid +import sys +from common.exceptions import AuthenticationError + + +class Backend(ABC): + """ + Abstact base class for implementations of a backend to inherit from + """ + + @abstractmethod + def login(self, credentials): + """ + Attempt to log a user in using the provided credentials + :param credentials: The user's credentials + :returns: a session ID + """ + pass + + @abstractmethod + def get_session_details(self, session_id): + """ + Returns the details of a user's session + :param session_id: The user's session ID + :returns: The user's session details + """ + pass + + @abstractmethod + def refresh(self, session_id): + """ + Attempts to refresh a user's session + :param session_id: The user's session ID + :returns: the user's refreshed session ID + """ + pass + + @abstractmethod + def logout(self, session_id): + """ + Logs a user out + :param session_id: The user's session ID + """ + pass + + @abstractmethod + def get_with_filters(self, session_id, entity_type, filters): + """ + Given a list of filters supplied in json format, returns entities that match the filters for the given entity type + :param session_id: The session id of the requesting user + :param entity_type: The type of entity + :param filters: The list of filters to be applied + :return: A list of the matching entities in json format + """ + pass + + @abstractmethod + def create(self, session_id, entity_type, data): + """ + Create one or more entities, from the given list containing json. Each entity must not contain its ID + :param session_id: The session id of the requesting user + :param entity_type: The type of entity + :param data: The entities to be created + :return: The created entities. + """ + pass + + @abstractmethod + def update(self, session_id, entity_type, data): + """ + Update one or more entities, from the given list containing json. Each entity must contain its ID + :param session_id: The session id of the requesting user + :param entity_type: The type of entity + :param data: the list of updated values or a dictionary + :return: The list of updated entities. + """ + pass + + @abstractmethod + def get_one_with_filters(self, session_id, entity_type, filters): + """ + returns the first entity that matches a given filter, for a given entity type + :param session_id: The session id of the requesting user + :param entity_type: The type of entity + :param filters: the filter to be applied to the query + :return: the first entity matching the filter + """ + pass + + @abstractmethod + def count_with_filters(self, session_id, entity_type, filters): + """ + returns the count of the entities that match a given filter for a given entity type + :param session_id: The session id of the requesting user + :param entity_type: The type of entity + :param filters: the filters to be applied to the query + :return: int: the count of the entities + """ + pass + + @abstractmethod + def get_with_id(self, session_id, entity_type, id): + """ + Gets the entity matching the given ID for the given entity type + :param session_id: The session id of the requesting user + :param entity_type: The type of entity + :param id: the id of the record to find + :return: the entity retrieved + """ + pass + + @abstractmethod + def delete_with_id(self, session_id, entity_type, id): + """ + Deletes the row matching the given ID for the given entity type + :param session_id: The session id of the requesting user + :param table: the table to be searched + :param id: the id of the record to delete + """ + pass + + @abstractmethod + def update_with_id(self, session_id, entity_type, id, data): + """ + Updates the row matching the given ID for the given entity type + :param session_id: The session id of the requesting user + :param entity_type: The type of entity + :param data: The dictionary that the entity should be updated with + :return: The updated entity. + """ + pass + + @abstractmethod + def get_users_investigations_with_filters(self, session_id, user_id, filters): + """ + Given a user id and a list of filters, return a filtered list of all investigations that belong to that user + :param session_id: The session id of the requesting user + :param user_id: The id of the user + :param filters: The list of filters + :return: A list of dictionary representations of the investigation entities + """ + pass + + @abstractmethod + def count_users_investigations_with_filters(self, session_id, user_id, filters): + """ + Given a user id and a list of filters, return the count of all investigations that belong to that user + :param session_id: The session id of the requesting user + :param user_id: The id of the user + :param filters: The list of filters + :return: The count + """ + pass + + @abstractmethod + def get_instrument_facilitycycles_with_filters(self, session_id, instrument_id, filters): + """ + Given an instrument_id get facility cycles where the instrument has investigations that occur within that cycle + :param session_id: The session id of the requesting user + :param filters: The filters to be applied to the query + :param instrument_id: The id of the instrument + :return: A list of facility cycle entities + """ + pass + + @abstractmethod + def count_instrument_facilitycycles_with_filters(self, session_id, instrument_id, filters): + """ + Given an instrument_id get the facility cycles count where the instrument has investigations that occur within + that cycle + :param session_id: The session id of the requesting user + :param filters: The filters to be applied to the query + :param instrument_id: The id of the instrument + :return: The count of the facility cycles + """ + pass + + @abstractmethod + def get_instrument_facilitycycle_investigations_with_filters(self, session_id, instrument_id, facilitycycle_id, filters): + """ + Given an instrument id and facility cycle id, get investigations that use the given instrument in the given cycle + :param session_id: The session id of the requesting user + :param filters: The filters to be applied to the query + :param instrument_id: The id of the instrument + :param facility_cycle_id: the ID of the facility cycle + :return: The investigations + """ + pass + + @abstractmethod + def count_instrument_facilitycycles_investigations_with_filters(self, session_id, instrument_id, facilitycycle_id, filters): + """ + Given an instrument id and facility cycle id, get the count of the investigations that use the given instrument in + the given cycle + :param session_id: The session id of the requesting user + :param filters: The filters to be applied to the query + :param instrument_id: The id of the instrument + :param facility_cycle_id: the ID of the facility cycle + :return: The investigations count + """ + pass + + +class DatabaseBackend(Backend): + """ + Class that contains functions to access and modify data in an ICAT database directly + """ + + def login(self, credentials): + if credentials["username"] == "user" and credentials["password"] == "password": + session_id = str(uuid.uuid1()) + insert_row_into_table(SESSION, SESSION(ID=session_id)) + return session_id + else: + raise AuthenticationError("Username and password are incorrect") + + @requires_session_id + def get_session_details(self, session_id): + return get_row_by_id(SESSION, session_id) + + @requires_session_id + def refresh(self, session_id): + return session_id + + @requires_session_id + @queries_records + def logout(self, session_id): + return delete_row_by_id(SESSION, session_id) + + @requires_session_id + @queries_records + def get_with_filters(self, session_id, table, filters): + return get_rows_by_filter(table, filters) + + @requires_session_id + @queries_records + def create(self, session_id, table, data): + return create_rows_from_json(table, data) + + @requires_session_id + @queries_records + def update(self, session_id, table, data): + return patch_entities(table, data) + + @requires_session_id + @queries_records + def get_one_with_filters(self, session_id, table, filters): + return get_first_filtered_row(table, filters) + + @requires_session_id + @queries_records + def count_with_filters(self, session_id, table, filters): + return get_filtered_row_count(table, filters) + + @requires_session_id + @queries_records + def get_with_id(self, session_id, table, id): + return get_row_by_id(table, id).to_dict() + + @requires_session_id + @queries_records + def delete_with_id(self, session_id, table, id): + return delete_row_by_id(table, id) + + @requires_session_id + @queries_records + def update_with_id(self, session_id, table, id, data): + return update_row_from_id(table, id, data) + + @requires_session_id + @queries_records + def get_users_investigations_with_filters(self, session_id, user_id, filters): + return get_investigations_for_user(user_id, filters) + + @requires_session_id + @queries_records + def count_users_investigations_with_filters(self, session_id, user_id, filters): + return get_investigations_for_user_count(user_id, filters) + + @requires_session_id + @queries_records + def get_instrument_facilitycycles_with_filters(self, session_id, instrument_id, filters): + return get_facility_cycles_for_instrument(instrument_id, filters) + + @requires_session_id + @queries_records + def count_instrument_facilitycycles_with_filters(self, session_id, instrument_id, filters): + return get_facility_cycles_for_instrument_count(instrument_id, filters) + + @requires_session_id + @queries_records + def get_instrument_facilitycycle_investigations_with_filters(self, session_id, instrument_id, facilitycycle_id, filters): + return get_investigations_for_instrument_in_facility_cycle(instrument_id, facilitycycle_id, filters) + + @requires_session_id + @queries_records + def count_instrument_facilitycycles_investigations_with_filters(self, session_id, instrument_id, facilitycycle_id, filters): + return get_investigations_for_instrument_in_facility_cycle_count(instrument_id, facilitycycle_id, filters) + + +backend_type = config.get_backend_type() + +if backend_type == "db": + backend = DatabaseBackend() +else: + sys.exit( + f"Invalid config value '{backend_type}' for config option backend") + backend = Backend() diff --git a/common/config.py b/common/config.py index 720f72a1..53643c1b 100644 --- a/common/config.py +++ b/common/config.py @@ -12,6 +12,12 @@ def __init__(self): self.config = json.load(target) target.close() + def get_backend_type(self): + try: + return self.config["backend"] + except: + sys.exit("Missing config value, backend") + def get_db_url(self): try: return self.config["DB_URL"] diff --git a/common/database_helpers.py b/common/database_helpers.py index d942b1f7..d17d2cb8 100644 --- a/common/database_helpers.py +++ b/common/database_helpers.py @@ -177,8 +177,10 @@ def apply_filter(self, query): if self.included_included_field: included_table = getattr(db_models, self.field) included_included_table = getattr(db_models, self.included_field) - query.base_query = query.base_query.join(included_table).join(included_included_table) - field = getattr(included_included_table, self.included_included_field) + query.base_query = query.base_query.join( + included_table).join(included_included_table) + field = getattr(included_included_table, + self.included_included_field) elif self.included_field: included_table = getattr(db_models, self.field) @@ -188,7 +190,8 @@ def apply_filter(self, query): if self.operation == "eq": query.base_query = query.base_query.filter(field == self.value) elif self.operation == "like": - query.base_query = query.base_query.filter(field.like(f"%{self.value}%")) + query.base_query = query.base_query.filter( + field.like(f"%{self.value}%")) elif self.operation == "lte": query.base_query = query.base_query.filter(field <= self.value) elif self.operation == "gte": @@ -196,19 +199,22 @@ def apply_filter(self, query): elif self.operation == "in": query.base_query = query.base_query.filter(field.in_(self.value)) else: - raise BadFilterError(f" Bad operation given to where filter. operation: {self.operation}") + raise BadFilterError( + f" Bad operation given to where filter. operation: {self.operation}") class DistinctFieldFilter(QueryFilter): precedence = 0 def __init__(self, fields): - self.fields = fields if type(fields) is list else [fields] # This allows single string distinct filters + # This allows single string distinct filters + self.fields = fields if type(fields) is list else [fields] def apply_filter(self, query): query.is_distinct_fields_query = True try: - self.fields = [getattr(query.table, field) for field in self.fields] + self.fields = [getattr(query.table, field) + for field in self.fields] except AttributeError: raise BadFilterError("Bad field requested") query.base_query = query.session.query(*self.fields).distinct() @@ -223,9 +229,11 @@ def __init__(self, field, direction): def apply_filter(self, query): if self.direction.upper() == "ASC": - query.base_query = query.base_query.order_by(asc(self.field.upper())) + query.base_query = query.base_query.order_by( + asc(self.field.upper())) elif self.direction.upper() == "DESC": - query.base_query = query.base_query.order_by(desc(self.field.upper())) + query.base_query = query.base_query.order_by( + desc(self.field.upper())) else: raise BadFilterError(f" Bad filter: {self.direction}") @@ -260,7 +268,7 @@ def apply_filter(self, query): if not query.include_related_entities: query.include_related_entities = True else: - raise MultipleIncludeError("Attempted multiple includes on a single query") + raise MultipleIncludeError() class QueryFilterFactory(object): @@ -515,7 +523,8 @@ class UserInvestigationsQuery(ReadQuery): def __init__(self, user_id): super().__init__(INVESTIGATION) - self.base_query = self.base_query.join(INVESTIGATIONUSER).filter(INVESTIGATIONUSER.USER_ID == user_id) + self.base_query = self.base_query.join(INVESTIGATIONUSER).filter( + INVESTIGATIONUSER.USER_ID == user_id) def get_investigations_for_user(user_id, filters): @@ -537,7 +546,8 @@ class UserInvestigationsCountQuery(CountQuery): def __init__(self, user_id): super().__init__(INVESTIGATION) - self.base_query = self.base_query.join(INVESTIGATIONUSER).filter(INVESTIGATIONUSER.USER_ID == user_id) + self.base_query = self.base_query.join(INVESTIGATIONUSER).filter( + INVESTIGATIONUSER.USER_ID == user_id) def get_investigations_for_user_count(user_id, filters): diff --git a/common/exceptions.py b/common/exceptions.py index 369ad2de..f0340aa4 100644 --- a/common/exceptions.py +++ b/common/exceptions.py @@ -1,30 +1,47 @@ +import werkzeug + + class ApiError(Exception): - pass + status_code = 500 class MissingRecordError(ApiError): - pass + def __init__(self, msg='No such record in table', *args, **kwargs): + super().__init__(msg, *args, **kwargs) + self.status_code = 404 class BadFilterError(ApiError): - pass + def __init__(self, msg='Invalid filter requested', *args, **kwargs): + super().__init__(msg, *args, **kwargs) + self.status_code = 400 class MultipleIncludeError(BadFilterError): - pass + def __init__(self, msg='Bad request, only one include filter may be given per request', *args, **kwargs): + super().__init__(msg, *args, **kwargs) + self.status_code = 400 class AuthenticationError(ApiError): - pass + def __init__(self, msg='Authentication error', *args, **kwargs): + super().__init__(msg, *args, **kwargs) + self.status_code = 403 class MissingCredentialsError(AuthenticationError): - pass + def __init__(self, msg='No credentials provided in auth header', *args, **kwargs): + super().__init__(msg, *args, **kwargs) + self.status_code = 401 class BadRequestError(ApiError): - pass + def __init__(self, msg='Bad request', *args, **kwargs): + super().__init__(msg, *args, **kwargs) + self.status_code = 400 class DatabaseError(ApiError): - pass + def __init__(self, msg='Database error', *args, **kwargs): + super().__init__(msg, *args, **kwargs) + self.status_code = 500 diff --git a/common/helpers.py b/common/helpers.py index 6d4bca15..2b924a9e 100644 --- a/common/helpers.py +++ b/common/helpers.py @@ -7,8 +7,7 @@ from sqlalchemy.exc import IntegrityError from common.database_helpers import QueryFilterFactory -from common.exceptions import MissingRecordError, BadFilterError, AuthenticationError, BadRequestError, \ - MissingCredentialsError, MultipleIncludeError +from common.exceptions import ApiError, AuthenticationError, BadFilterError, BadRequestError, MissingCredentialsError, MissingRecordError, MultipleIncludeError from common.models.db_models import SESSION from common.session_manager import session_manager @@ -17,34 +16,28 @@ def requires_session_id(method): """ - Decorator for endpoint resources that makes sure a valid session_id is provided in requests to that endpoint - :param method: The method for the endpoint - :returns a 403, "Forbidden" if a valid session_id is not provided with the request + Decorator for database backend methods that makes sure a valid session_id is provided + It expects that session_id is the second argument supplied to the function + :param method: The method for the backend operation + :raises AuthenticationError, if a valid session_id is not provided with the request """ - @wraps(method) def wrapper_requires_session(*args, **kwargs): log.info(" Authenticating consumer") - try: - session = session_manager.get_icat_db_session() - query = session.query(SESSION).filter( - SESSION.ID == get_session_id_from_auth_header()).first() - if query is not None: - log.info(" Closing DB session") - session.close() - session.close() - log.info(" Consumer authenticated") - return method(*args, **kwargs) - else: - log.info(" Could not authenticate consumer, closing DB session") - session.close() - return "Forbidden", 403 - except MissingCredentialsError: - return "Unauthorized", 401 - except AuthenticationError: - return "Forbidden", 403 - + session = session_manager.get_icat_db_session() + query = session.query(SESSION).filter( + SESSION.ID == args[1]).first() + if query is not None: + log.info(" Closing DB session") + session.close() + session.close() + log.info(" Consumer authenticated") + return method(*args, **kwargs) + else: + log.info(" Could not authenticate consumer, closing DB session") + session.close() + raise AuthenticationError("Forbidden") return wrapper_requires_session @@ -60,27 +53,18 @@ def queries_records(method): def wrapper_gets_records(*args, **kwargs): try: return method(*args, **kwargs) - except MissingRecordError as e: + except ApiError as e: log.exception(e) - return "No such record in table", 404 - except MultipleIncludeError as e: - log.exception(e) - return "Bad request, only one include filter may be given per request", 400 - except BadFilterError as e: - log.exception(e) - return "Invalid filter requested", 400 + raise e except ValueError as e: log.exception(e) - return "Bad request", 400 + raise BadRequestError() except TypeError as e: log.exception(e) - return "Bad request", 400 + raise BadRequestError() except IntegrityError as e: log.exception(e) - return "Bad request", 400 - except BadRequestError as e: - log.exception(e) - return "Bad request", 400 + raise BadRequestError() return wrapper_gets_records @@ -93,11 +77,14 @@ def get_session_id_from_auth_header(): parser = reqparse.RequestParser() parser.add_argument("Authorization", location="headers") args = parser.parse_args() - auth_header = args["Authorization"].split(" ") if args["Authorization"] is not None else "" + auth_header = args["Authorization"].split( + " ") if args["Authorization"] is not None else "" if auth_header == "": - raise MissingCredentialsError(f"No credentials provided in auth header") + raise MissingCredentialsError( + f"No credentials provided in auth header") if len(auth_header) != 2 or auth_header[0] != "Bearer": - raise AuthenticationError(f" Could not authenticate consumer with auth header {auth_header}") + raise AuthenticationError( + f" Could not authenticate consumer with auth header {auth_header}") return auth_header[1] @@ -122,8 +109,12 @@ def get_filters_from_query_string(): :return: The list of filters """ log.info(" Getting filters from query string") - filters = [] - for arg in request.args: - for value in request.args.getlist(arg): - filters.append(QueryFilterFactory.get_query_filter({arg: json.loads(value)})) - return filters + try: + filters = [] + for arg in request.args: + for value in request.args.getlist(arg): + filters.append(QueryFilterFactory.get_query_filter( + {arg: json.loads(value)})) + return filters + except: + raise BadFilterError() diff --git a/common/models/db_models.py b/common/models/db_models.py index 23b45696..2c2ceb9c 100644 --- a/common/models/db_models.py +++ b/common/models/db_models.py @@ -31,9 +31,12 @@ def process_bind_param(self, value, dialect): def process_result_value(self, value, dialect): try: - return f"{self.enum_type(value)}".replace(f"{self.enum_type.__name__}.", "") # Strips the enum class name + # Strips the enum class name + return f"{self.enum_type(value)}".replace(f"{self.enum_type.__name__}.", "") except ValueError: - raise DatabaseError(f"value {value} not in {self.enum_type.__name__}") # This will force a 500 response + # This will force a 500 response + raise DatabaseError( + f"value {value} not in {self.enum_type.__name__}") def copy(self, **kwargs): return EnumAsInteger(self.enum_type) @@ -83,7 +86,8 @@ def to_nested_dict(self, includes): elif type(include) is dict: self._nest_dictionary_include(dictionary, include) except TypeError: - raise BadFilterError(f" Bad include relations provided: {includes}") + raise BadFilterError( + f" Bad include relations provided: {includes}") return dictionary def _nest_dictionary_include(self, dictionary, include): @@ -95,13 +99,16 @@ def _nest_dictionary_include(self, dictionary, include): """ related_entity = self.get_related_entity(list(include)[0]) if not isinstance(related_entity, InstrumentedList): - dictionary[related_entity.__tablename__] = related_entity.to_nested_dict(include[list(include)[0]]) + dictionary[related_entity.__tablename__] = related_entity.to_nested_dict( + include[list(include)[0]]) else: for entity in related_entity: if entity.__tablename__ in dictionary.keys(): - dictionary[entity.__tablename__].append(entity.to_nested_dict(include[list(include)[0]])) + dictionary[entity.__tablename__].append( + entity.to_nested_dict(include[list(include)[0]])) else: - dictionary[entity.__tablename__] = [entity.to_nested_dict(include[list(include)[0]])] + dictionary[entity.__tablename__] = [ + entity.to_nested_dict(include[list(include)[0]])] def _nest_string_include(self, dictionary, include): """ @@ -135,9 +142,11 @@ def update_from_dict(self, dictionary): """ Given a dictionary containing field names and variables, updates the entity from the given dictionary :param dictionary: dict: dictionary containing the new values + :returns: The updated dict """ for key in dictionary: setattr(self, key.upper(), dictionary[key]) + return self.to_dict() class APPLICATION(Base, EntityHelper): @@ -155,7 +164,8 @@ class APPLICATION(Base, EntityHelper): VERSION = Column(String(255), nullable=False) FACILITY_ID = Column(ForeignKey('FACILITY.ID'), nullable=False) - FACILITY = relationship('FACILITY', primaryjoin='APPLICATION.FACILITY_ID == FACILITY.ID', backref='APPLICATION') + FACILITY = relationship( + 'FACILITY', primaryjoin='APPLICATION.FACILITY_ID == FACILITY.ID', backref='APPLICATION') class FACILITY(Base, EntityHelper): @@ -187,7 +197,8 @@ class DATACOLLECTION(Base, EntityHelper): class DATACOLLECTIONDATAFILE(Base, EntityHelper): __tablename__ = 'DATACOLLECTIONDATAFILE' __table_args__ = ( - Index('UNQ_DATACOLLECTIONDATAFILE_0', 'DATACOLLECTION_ID', 'DATAFILE_ID'), + Index('UNQ_DATACOLLECTIONDATAFILE_0', + 'DATACOLLECTION_ID', 'DATAFILE_ID'), ) ID = Column(BigInteger, primary_key=True) @@ -208,7 +219,8 @@ class DATACOLLECTIONDATAFILE(Base, EntityHelper): class DATACOLLECTIONDATASET(Base, EntityHelper): __tablename__ = 'DATACOLLECTIONDATASET' __table_args__ = ( - Index('UNQ_DATACOLLECTIONDATASET_0', 'DATACOLLECTION_ID', 'DATASET_ID'), + Index('UNQ_DATACOLLECTIONDATASET_0', + 'DATACOLLECTION_ID', 'DATASET_ID'), ) ID = Column(BigInteger, primary_key=True) @@ -229,7 +241,8 @@ class DATACOLLECTIONDATASET(Base, EntityHelper): class DATACOLLECTIONPARAMETER(Base, EntityHelper): __tablename__ = 'DATACOLLECTIONPARAMETER' __table_args__ = ( - Index('UNQ_DATACOLLECTIONPARAMETER_0', 'DATACOLLECTION_ID', 'PARAMETER_TYPE_ID'), + Index('UNQ_DATACOLLECTIONPARAMETER_0', + 'DATACOLLECTION_ID', 'PARAMETER_TYPE_ID'), ) ID = Column(BigInteger, primary_key=True) @@ -244,7 +257,8 @@ class DATACOLLECTIONPARAMETER(Base, EntityHelper): RANGETOP = Column(Float(asdecimal=True)) STRING_VALUE = Column(String(4000)) DATACOLLECTION_ID = Column(ForeignKey('DATACOLLECTION.ID'), nullable=False) - PARAMETER_TYPE_ID = Column(ForeignKey('PARAMETERTYPE.ID'), nullable=False, index=True) + PARAMETER_TYPE_ID = Column(ForeignKey( + 'PARAMETERTYPE.ID'), nullable=False, index=True) DATACOLLECTION = relationship('DATACOLLECTION', primaryjoin='DATACOLLECTIONPARAMETER.DATACOLLECTION_ID == DATACOLLECTION.ID', @@ -278,7 +292,8 @@ class DATAFILE(Base, EntityHelper): DATAFILEFORMAT = relationship('DATAFILEFORMAT', primaryjoin='DATAFILE.DATAFILEFORMAT_ID == DATAFILEFORMAT.ID', backref='DATAFILE') - DATASET = relationship('DATASET', primaryjoin='DATAFILE.DATASET_ID == DATASET.ID', backref='DATAFILE') + DATASET = relationship( + 'DATASET', primaryjoin='DATAFILE.DATASET_ID == DATASET.ID', backref='DATAFILE') class DATAFILEFORMAT(Base, EntityHelper): @@ -320,7 +335,8 @@ class DATAFILEPARAMETER(Base, EntityHelper): RANGETOP = Column(Float(asdecimal=True)) STRING_VALUE = Column(String(4000)) DATAFILE_ID = Column(ForeignKey('DATAFILE.ID'), nullable=False) - PARAMETER_TYPE_ID = Column(ForeignKey('PARAMETERTYPE.ID'), nullable=False, index=True) + PARAMETER_TYPE_ID = Column(ForeignKey( + 'PARAMETERTYPE.ID'), nullable=False, index=True) DATAFILE = relationship('DATAFILE', primaryjoin='DATAFILEPARAMETER.DATAFILE_ID == DATAFILE.ID', backref='DATAFILEPARAMETER') @@ -352,8 +368,10 @@ class DATASET(Base, EntityHelper): INVESTIGATION = relationship('INVESTIGATION', primaryjoin='DATASET.INVESTIGATION_ID == INVESTIGATION.ID', backref='DATASET') - SAMPLE = relationship('SAMPLE', primaryjoin='DATASET.SAMPLE_ID == SAMPLE.ID', backref='DATASET') - DATASETTYPE = relationship('DATASETTYPE', primaryjoin='DATASET.TYPE_ID == DATASETTYPE.ID', backref='DATASET') + SAMPLE = relationship( + 'SAMPLE', primaryjoin='DATASET.SAMPLE_ID == SAMPLE.ID', backref='DATASET') + DATASETTYPE = relationship( + 'DATASETTYPE', primaryjoin='DATASET.TYPE_ID == DATASETTYPE.ID', backref='DATASET') class DATASETPARAMETER(Base, EntityHelper): @@ -374,7 +392,8 @@ class DATASETPARAMETER(Base, EntityHelper): RANGETOP = Column(Float(asdecimal=True)) STRING_VALUE = Column(String(4000)) DATASET_ID = Column(ForeignKey('DATASET.ID'), nullable=False) - PARAMETER_TYPE_ID = Column(ForeignKey('PARAMETERTYPE.ID'), nullable=False, index=True) + PARAMETER_TYPE_ID = Column(ForeignKey( + 'PARAMETERTYPE.ID'), nullable=False, index=True) DATASET = relationship('DATASET', primaryjoin='DATASETPARAMETER.DATASET_ID == DATASET.ID', backref='DATASETPARAMETER') @@ -397,7 +416,8 @@ class DATASETTYPE(Base, EntityHelper): NAME = Column(String(255), nullable=False) FACILITY_ID = Column(ForeignKey('FACILITY.ID'), nullable=False) - FACILITY = relationship('FACILITY', primaryjoin='DATASETTYPE.FACILITY_ID == FACILITY.ID', backref='DATASETTYPE') + FACILITY = relationship( + 'FACILITY', primaryjoin='DATASETTYPE.FACILITY_ID == FACILITY.ID', backref='DATASETTYPE') class FACILITYCYCLE(Base, EntityHelper): @@ -450,7 +470,8 @@ class INSTRUMENT(Base, EntityHelper): URL = Column(String(255)) FACILITY_ID = Column(ForeignKey('FACILITY.ID'), nullable=False) - FACILITY = relationship('FACILITY', primaryjoin='INSTRUMENT.FACILITY_ID == FACILITY.ID', backref='INSTRUMENT') + FACILITY = relationship( + 'FACILITY', primaryjoin='INSTRUMENT.FACILITY_ID == FACILITY.ID', backref='INSTRUMENT') class INSTRUMENTSCIENTIST(Base, EntityHelper): @@ -464,12 +485,14 @@ class INSTRUMENTSCIENTIST(Base, EntityHelper): CREATE_TIME = Column(DateTime, nullable=False) MOD_ID = Column(String(255), nullable=False) MOD_TIME = Column(DateTime, nullable=False) - INSTRUMENT_ID = Column(ForeignKey('INSTRUMENT.ID'), nullable=False, index=True) + INSTRUMENT_ID = Column(ForeignKey('INSTRUMENT.ID'), + nullable=False, index=True) USER_ID = Column(ForeignKey('USER_.ID'), nullable=False) INSTRUMENT = relationship('INSTRUMENT', primaryjoin='INSTRUMENTSCIENTIST.INSTRUMENT_ID == INSTRUMENT.ID', backref='INSTRUMENTSCIENTIST') - USER_ = relationship('USER', primaryjoin='INSTRUMENTSCIENTIST.USER_ID == USER.ID', backref='INSTRUMENTSCIENTIST') + USER_ = relationship( + 'USER', primaryjoin='INSTRUMENTSCIENTIST.USER_ID == USER.ID', backref='INSTRUMENTSCIENTIST') class INVESTIGATION(Base, EntityHelper): @@ -492,7 +515,8 @@ class INVESTIGATION(Base, EntityHelper): TITLE = Column(String(255), nullable=False) VISIT_ID = Column(String(255), nullable=False) FACILITY_ID = Column(ForeignKey('FACILITY.ID'), nullable=False) - TYPE_ID = Column(ForeignKey('INVESTIGATIONTYPE.ID'), nullable=False, index=True) + TYPE_ID = Column(ForeignKey('INVESTIGATIONTYPE.ID'), + nullable=False, index=True) FACILITY = relationship('FACILITY', primaryjoin='INVESTIGATION.FACILITY_ID == FACILITY.ID', backref='INVESTIGATION') @@ -503,7 +527,8 @@ class INVESTIGATION(Base, EntityHelper): class INVESTIGATIONGROUP(Base, EntityHelper): __tablename__ = 'INVESTIGATIONGROUP' __table_args__ = ( - Index('UNQ_INVESTIGATIONGROUP_0', 'GROUP_ID', 'INVESTIGATION_ID', 'ROLE'), + Index('UNQ_INVESTIGATIONGROUP_0', 'GROUP_ID', + 'INVESTIGATION_ID', 'ROLE'), ) ID = Column(BigInteger, primary_key=True) @@ -513,7 +538,8 @@ class INVESTIGATIONGROUP(Base, EntityHelper): MOD_TIME = Column(DateTime, nullable=False) ROLE = Column(String(255), nullable=False) GROUP_ID = Column(ForeignKey('GROUPING.ID'), nullable=False) - INVESTIGATION_ID = Column(ForeignKey('INVESTIGATION.ID'), nullable=False, index=True) + INVESTIGATION_ID = Column(ForeignKey( + 'INVESTIGATION.ID'), nullable=False, index=True) GROUPING = relationship('GROUPING', primaryjoin='INVESTIGATIONGROUP.GROUP_ID == GROUPING.ID', backref='INVESTIGATIONGROUP') @@ -524,7 +550,8 @@ class INVESTIGATIONGROUP(Base, EntityHelper): class INVESTIGATIONINSTRUMENT(Base, EntityHelper): __tablename__ = 'INVESTIGATIONINSTRUMENT' __table_args__ = ( - Index('UNQ_INVESTIGATIONINSTRUMENT_0', 'INVESTIGATION_ID', 'INSTRUMENT_ID'), + Index('UNQ_INVESTIGATIONINSTRUMENT_0', + 'INVESTIGATION_ID', 'INSTRUMENT_ID'), ) ID = Column(BigInteger, primary_key=True) @@ -532,7 +559,8 @@ class INVESTIGATIONINSTRUMENT(Base, EntityHelper): CREATE_TIME = Column(DateTime, nullable=False) MOD_ID = Column(String(255), nullable=False) MOD_TIME = Column(DateTime, nullable=False) - INSTRUMENT_ID = Column(ForeignKey('INSTRUMENT.ID'), nullable=False, index=True) + INSTRUMENT_ID = Column(ForeignKey('INSTRUMENT.ID'), + nullable=False, index=True) INVESTIGATION_ID = Column(ForeignKey('INVESTIGATION.ID'), nullable=False) INSTRUMENT = relationship('INSTRUMENT', primaryjoin='INVESTIGATIONINSTRUMENT.INSTRUMENT_ID == INSTRUMENT.ID', @@ -545,7 +573,8 @@ class INVESTIGATIONINSTRUMENT(Base, EntityHelper): class INVESTIGATIONPARAMETER(Base, EntityHelper): __tablename__ = 'INVESTIGATIONPARAMETER' __table_args__ = ( - Index('UNQ_INVESTIGATIONPARAMETER_0', 'INVESTIGATION_ID', 'PARAMETER_TYPE_ID'), + Index('UNQ_INVESTIGATIONPARAMETER_0', + 'INVESTIGATION_ID', 'PARAMETER_TYPE_ID'), ) ID = Column(BigInteger, primary_key=True) @@ -560,7 +589,8 @@ class INVESTIGATIONPARAMETER(Base, EntityHelper): RANGETOP = Column(Float(asdecimal=True)) STRING_VALUE = Column(String(4000)) INVESTIGATION_ID = Column(ForeignKey('INVESTIGATION.ID'), nullable=False) - PARAMETER_TYPE_ID = Column(ForeignKey('PARAMETERTYPE.ID'), nullable=False, index=True) + PARAMETER_TYPE_ID = Column(ForeignKey( + 'PARAMETERTYPE.ID'), nullable=False, index=True) INVESTIGATION = relationship('INVESTIGATION', primaryjoin='INVESTIGATIONPARAMETER.INVESTIGATION_ID == INVESTIGATION.ID', @@ -601,12 +631,14 @@ class INVESTIGATIONUSER(Base, EntityHelper): MOD_ID = Column(String(255), nullable=False) MOD_TIME = Column(DateTime, nullable=False) ROLE = Column(String(255), nullable=False) - INVESTIGATION_ID = Column(ForeignKey('INVESTIGATION.ID'), nullable=False, index=True) + INVESTIGATION_ID = Column(ForeignKey( + 'INVESTIGATION.ID'), nullable=False, index=True) USER_ID = Column(ForeignKey('USER_.ID'), nullable=False) INVESTIGATION = relationship('INVESTIGATION', primaryjoin='INVESTIGATIONUSER.INVESTIGATION_ID == INVESTIGATION.ID', backref='INVESTIGATIONUSER') - USER_ = relationship('USER', primaryjoin='INVESTIGATIONUSER.USER_ID == USER.ID', backref='INVESTIGATIONUSER') + USER_ = relationship( + 'USER', primaryjoin='INVESTIGATIONUSER.USER_ID == USER.ID', backref='INVESTIGATIONUSER') class JOB(Base, EntityHelper): @@ -618,11 +650,15 @@ class JOB(Base, EntityHelper): CREATE_TIME = Column(DateTime, nullable=False) MOD_ID = Column(String(255), nullable=False) MOD_TIME = Column(DateTime, nullable=False) - APPLICATION_ID = Column(ForeignKey('APPLICATION.ID'), nullable=False, index=True) - INPUTDATACOLLECTION_ID = Column(ForeignKey('DATACOLLECTION.ID'), index=True) - OUTPUTDATACOLLECTION_ID = Column(ForeignKey('DATACOLLECTION.ID'), index=True) + APPLICATION_ID = Column(ForeignKey('APPLICATION.ID'), + nullable=False, index=True) + INPUTDATACOLLECTION_ID = Column( + ForeignKey('DATACOLLECTION.ID'), index=True) + OUTPUTDATACOLLECTION_ID = Column( + ForeignKey('DATACOLLECTION.ID'), index=True) - APPLICATION = relationship('APPLICATION', primaryjoin='JOB.APPLICATION_ID == APPLICATION.ID', backref='JOB') + APPLICATION = relationship( + 'APPLICATION', primaryjoin='JOB.APPLICATION_ID == APPLICATION.ID', backref='JOB') DATACOLLECTION = relationship('DATACOLLECTION', primaryjoin='JOB.INPUTDATACOLLECTION_ID == DATACOLLECTION.ID', backref='JOB') @@ -639,7 +675,8 @@ class KEYWORD(Base, EntityHelper): MOD_ID = Column(String(255), nullable=False) MOD_TIME = Column(DateTime, nullable=False) NAME = Column(String(255), nullable=False) - INVESTIGATION_ID = Column(ForeignKey('INVESTIGATION.ID'), nullable=False, index=True) + INVESTIGATION_ID = Column(ForeignKey( + 'INVESTIGATION.ID'), nullable=False, index=True) INVESTIGATION = relationship('INVESTIGATION', primaryjoin='KEYWORD.INVESTIGATION_ID == INVESTIGATION.ID', backref='KEYWORD') @@ -693,7 +730,8 @@ class PERMISSIBLESTRINGVALUE(Base, EntityHelper): MOD_ID = Column(String(255), nullable=False) MOD_TIME = Column(DateTime, nullable=False) VALUE = Column(String(255), nullable=False) - PARAMETERTYPE_ID = Column(ForeignKey('PARAMETERTYPE.ID'), nullable=False, index=True) + PARAMETERTYPE_ID = Column(ForeignKey( + 'PARAMETERTYPE.ID'), nullable=False, index=True) PARAMETERTYPE = relationship('PARAMETERTYPE', primaryjoin='PERMISSIBLESTRINGVALUE.PARAMETERTYPE_ID == PARAMETERTYPE.ID', @@ -713,7 +751,8 @@ class PUBLICATION(Base, EntityHelper): REPOSITORY = Column(String(255)) REPOSITORYID = Column(String(255)) URL = Column(String(255)) - INVESTIGATION_ID = Column(ForeignKey('INVESTIGATION.ID'), nullable=False, index=True) + INVESTIGATION_ID = Column(ForeignKey( + 'INVESTIGATION.ID'), nullable=False, index=True) INVESTIGATION = relationship('INVESTIGATION', primaryjoin='PUBLICATION.INVESTIGATION_ID == INVESTIGATION.ID', backref='PUBLICATION') @@ -746,7 +785,8 @@ class RELATEDDATAFILE(Base, EntityHelper): MOD_ID = Column(String(255), nullable=False) MOD_TIME = Column(DateTime, nullable=False) RELATION = Column(String(255), nullable=False) - DEST_DATAFILE_ID = Column(ForeignKey('DATAFILE.ID'), nullable=False, index=True) + DEST_DATAFILE_ID = Column(ForeignKey( + 'DATAFILE.ID'), nullable=False, index=True) SOURCE_DATAFILE_ID = Column(ForeignKey('DATAFILE.ID'), nullable=False) DATAFILE = relationship('DATAFILE', primaryjoin='RELATEDDATAFILE.DEST_DATAFILE_ID == DATAFILE.ID', @@ -775,7 +815,8 @@ class RULE(Base, EntityHelper): WHAT = Column(String(1024), nullable=False) GROUPING_ID = Column(ForeignKey('GROUPING.ID'), index=True) - GROUPING = relationship('GROUPING', primaryjoin='RULE.GROUPING_ID == GROUPING.ID', backref='RULE') + GROUPING = relationship( + 'GROUPING', primaryjoin='RULE.GROUPING_ID == GROUPING.ID', backref='RULE') class SAMPLE(Base, EntityHelper): @@ -795,7 +836,8 @@ class SAMPLE(Base, EntityHelper): INVESTIGATION = relationship('INVESTIGATION', primaryjoin='SAMPLE.INVESTIGATION_ID == INVESTIGATION.ID', backref='SAMPLE') - SAMPLETYPE = relationship('SAMPLETYPE', primaryjoin='SAMPLE.SAMPLETYPE_ID == SAMPLETYPE.ID', backref='SAMPLE') + SAMPLETYPE = relationship( + 'SAMPLETYPE', primaryjoin='SAMPLE.SAMPLETYPE_ID == SAMPLETYPE.ID', backref='SAMPLE') class SAMPLEPARAMETER(Base, EntityHelper): @@ -816,11 +858,13 @@ class SAMPLEPARAMETER(Base, EntityHelper): RANGETOP = Column(Float(asdecimal=True)) STRING_VALUE = Column(String(4000)) SAMPLE_ID = Column(ForeignKey('SAMPLE.ID'), nullable=False) - PARAMETER_TYPE_ID = Column(ForeignKey('PARAMETERTYPE.ID'), nullable=False, index=True) + PARAMETER_TYPE_ID = Column(ForeignKey( + 'PARAMETERTYPE.ID'), nullable=False, index=True) PARAMETERTYPE = relationship('PARAMETERTYPE', primaryjoin='SAMPLEPARAMETER.PARAMETER_TYPE_ID == PARAMETERTYPE.ID', backref='SAMPLEPARAMETER') - SAMPLE = relationship('SAMPLE', primaryjoin='SAMPLEPARAMETER.SAMPLE_ID == SAMPLE.ID', backref='SAMPLEPARAMETER') + SAMPLE = relationship( + 'SAMPLE', primaryjoin='SAMPLEPARAMETER.SAMPLE_ID == SAMPLE.ID', backref='SAMPLEPARAMETER') class SESSION(Base, EntityHelper): @@ -879,8 +923,10 @@ class USERGROUP(Base, EntityHelper): GROUP_ID = Column(ForeignKey('GROUPING.ID'), nullable=False, index=True) USER_ID = Column(ForeignKey('USER_.ID'), nullable=False) - GROUPING = relationship('GROUPING', primaryjoin='USERGROUP.GROUP_ID == GROUPING.ID', backref='USERGROUP') - USER_ = relationship('USER', primaryjoin='USERGROUP.USER_ID == USER.ID', backref='USERGROUP') + GROUPING = relationship( + 'GROUPING', primaryjoin='USERGROUP.GROUP_ID == GROUPING.ID', backref='USERGROUP') + USER_ = relationship( + 'USER', primaryjoin='USERGROUP.USER_ID == USER.ID', backref='USERGROUP') class STUDYINVESTIGATION(Base, EntityHelper): @@ -894,12 +940,14 @@ class STUDYINVESTIGATION(Base, EntityHelper): CREATE_TIME = Column(DateTime, nullable=False) MOD_ID = Column(String(255), nullable=False) MOD_TIME = Column(DateTime, nullable=False) - INVESTIGATION_ID = Column(ForeignKey('INVESTIGATION.ID'), nullable=False, index=True) + INVESTIGATION_ID = Column(ForeignKey( + 'INVESTIGATION.ID'), nullable=False, index=True) STUDY_ID = Column(ForeignKey('STUDY.ID'), nullable=False) INVESTIGATION = relationship('INVESTIGATION', primaryjoin='STUDYINVESTIGATION.INVESTIGATION_ID == INVESTIGATION.ID', backref='STUDYINVESTIGATION') - STUDY = relationship('STUDY', primaryjoin='STUDYINVESTIGATION.STUDY_ID == STUDY.ID', backref='STUDYINVESTIGATION') + STUDY = relationship( + 'STUDY', primaryjoin='STUDYINVESTIGATION.STUDY_ID == STUDY.ID', backref='STUDYINVESTIGATION') class STUDY(Base, EntityHelper): @@ -916,7 +964,8 @@ class STUDY(Base, EntityHelper): STATUS = Column(Integer) USER_ID = Column(ForeignKey('USER_.ID'), index=True) - USER_ = relationship('USER', primaryjoin='STUDY.USER_ID == USER.ID', backref='STUDY') + USER_ = relationship( + 'USER', primaryjoin='STUDY.USER_ID == USER.ID', backref='STUDY') class SAMPLETYPE(Base, EntityHelper): @@ -935,4 +984,5 @@ class SAMPLETYPE(Base, EntityHelper): SAFETYINFORMATION = Column(String(4000)) FACILITY_ID = Column(ForeignKey('FACILITY.ID'), nullable=False) - FACILITY = relationship('FACILITY', primaryjoin='SAMPLETYPE.FACILITY_ID == FACILITY.ID', backref='SAMPLETYPE') + FACILITY = relationship( + 'FACILITY', primaryjoin='SAMPLETYPE.FACILITY_ID == FACILITY.ID', backref='SAMPLETYPE') diff --git a/config.json.example b/config.json.example index 6ff8e463..391d78d3 100644 --- a/config.json.example +++ b/config.json.example @@ -1,4 +1,5 @@ { + "backend": "db", "DB_URL": "mysql+pymysql://root:rootpw@localhost:13306/icatdb", "log_level": "WARN", "debug_mode": false, diff --git a/src/main.py b/src/main.py index 390bfb0b..69be1555 100644 --- a/src/main.py +++ b/src/main.py @@ -12,6 +12,7 @@ InstrumentsFacilityCycles, InstrumentsFacilityCyclesCount, InstrumentsFacilityCyclesInvestigations, \ InstrumentsFacilityCyclesInvestigationsCount from src.swagger.swagger_generator import swagger_gen +from common.exceptions import ApiError swagger_gen.write_swagger_spec() @@ -20,26 +21,41 @@ app.url_map.strict_slashes = False api = Api(app) + +@app.errorhandler(ApiError) +def handle_error(e): + return str(e), e.status_code + + setup_logger() for entity_name in endpoints: - api.add_resource(get_endpoint(entity_name, endpoints[entity_name]), f"/{entity_name.lower()}") - api.add_resource(get_id_endpoint(entity_name, endpoints[entity_name]), f"/{entity_name.lower()}/") - api.add_resource(get_count_endpoint(entity_name, endpoints[entity_name]), f"/{entity_name.lower()}/count") - api.add_resource(get_find_one_endpoint(entity_name, endpoints[entity_name]), f"/{entity_name.lower()}/findone") + api.add_resource(get_endpoint( + entity_name, endpoints[entity_name]), f"/{entity_name.lower()}") + api.add_resource(get_id_endpoint( + entity_name, endpoints[entity_name]), f"/{entity_name.lower()}/") + api.add_resource(get_count_endpoint( + entity_name, endpoints[entity_name]), f"/{entity_name.lower()}/count") + api.add_resource(get_find_one_endpoint( + entity_name, endpoints[entity_name]), f"/{entity_name.lower()}/findone") # Session endpoint api.add_resource(Sessions, "/sessions") # Table specific endpoints api.add_resource(UsersInvestigations, "/users//investigations") -api.add_resource(UsersInvestigationsCount, "/users//investigations/count") -api.add_resource(InstrumentsFacilityCycles, "/instruments//facilitycycles") -api.add_resource(InstrumentsFacilityCyclesCount, "/instruments//facilitycycles/count") +api.add_resource(UsersInvestigationsCount, + "/users//investigations/count") +api.add_resource(InstrumentsFacilityCycles, + "/instruments//facilitycycles") +api.add_resource(InstrumentsFacilityCyclesCount, + "/instruments//facilitycycles/count") api.add_resource(InstrumentsFacilityCyclesInvestigations, "/instruments//facilitycycles//investigations") api.add_resource(InstrumentsFacilityCyclesInvestigationsCount, "/instruments//facilitycycles//investigations/count") + if __name__ == "__main__": - app.run(host=config.get_host(), port=config.get_port(), debug=config.is_debug_mode()) + app.run(host=config.get_host(), port=config.get_port(), + debug=config.is_debug_mode()) diff --git a/src/resources/entities/entity_endpoint.py b/src/resources/entities/entity_endpoint.py index a7f5fee5..2321e840 100644 --- a/src/resources/entities/entity_endpoint.py +++ b/src/resources/entities/entity_endpoint.py @@ -3,7 +3,8 @@ from common.database_helpers import get_rows_by_filter, create_rows_from_json, patch_entities, get_row_by_id, \ delete_row_by_id, update_row_from_id, get_filtered_row_count, get_first_filtered_row -from common.helpers import requires_session_id, queries_records, get_filters_from_query_string +from common.helpers import get_session_id_from_auth_header, get_filters_from_query_string +from common.backends import backend def get_endpoint(name, table): @@ -16,20 +17,14 @@ def get_endpoint(name, table): :return: The generated endpoint class """ class Endpoint(Resource): - @requires_session_id - @queries_records def get(self): - return get_rows_by_filter(table, get_filters_from_query_string()), 200 + return backend.get_with_filters(get_session_id_from_auth_header(), table, get_filters_from_query_string()), 200 - @requires_session_id - @queries_records def post(self): - return create_rows_from_json(table, request.json), 200 + return backend.create(get_session_id_from_auth_header(), table, request.json), 200 - @requires_session_id - @queries_records def patch(self): - return list(map(lambda x: x.to_dict(), patch_entities(table, request.json))), 200 + return list(map(lambda x: x.to_dict(), backend.update(get_session_id_from_auth_header(), table, request.json))), 200 Endpoint.__name__ = name return Endpoint @@ -46,22 +41,18 @@ def get_id_endpoint(name, table): """ class EndpointWithID(Resource): - @requires_session_id - @queries_records def get(self, id): - return get_row_by_id(table, id).to_dict(), 200 + return backend.get_with_id(get_session_id_from_auth_header(), table, id).to_dict(), 200 - @requires_session_id - @queries_records def delete(self, id): - delete_row_by_id(table, id) + backend.delete_with_id( + get_session_id_from_auth_header(), table, id) return "", 204 - @requires_session_id - @queries_records def patch(self, id): - update_row_from_id(table, id, request.json) - return get_row_by_id(table, id).to_dict(), 200 + session_id = get_session_id_from_auth_header() + backend.update_with_id(session_id, table, id, request.json) + return backend.get_with_id(session_id, table, id).to_dict(), 200 EndpointWithID.__name__ = f"{name}WithID" return EndpointWithID @@ -78,11 +69,9 @@ def get_count_endpoint(name, table): """ class CountEndpoint(Resource): - @requires_session_id - @queries_records def get(self): filters = get_filters_from_query_string() - return get_filtered_row_count(table, filters), 200 + return backend.count_with_filters(get_session_id_from_auth_header(), table, filters), 200 CountEndpoint.__name__ = f"{name}Count" return CountEndpoint @@ -99,11 +88,9 @@ def get_find_one_endpoint(name, table): """ class FindOneEndpoint(Resource): - @requires_session_id - @queries_records def get(self): filters = get_filters_from_query_string() - return get_first_filtered_row(table, filters), 200 + return backend.get_one_with_filters(get_session_id_from_auth_header(), table, filters), 200 FindOneEndpoint.__name__ = f"{name}FindOne" return FindOneEndpoint diff --git a/src/resources/non_entities/sessions_endpoints.py b/src/resources/non_entities/sessions_endpoints.py index 157258b9..04140925 100644 --- a/src/resources/non_entities/sessions_endpoints.py +++ b/src/resources/non_entities/sessions_endpoints.py @@ -4,8 +4,10 @@ from flask_restful import Resource, reqparse from common.database_helpers import insert_row_into_table, delete_row_by_id, get_row_by_id -from common.helpers import get_session_id_from_auth_header, requires_session_id, queries_records +from common.helpers import get_session_id_from_auth_header from common.models.db_models import SESSION +from common.backends import backend +from common.exceptions import AuthenticationError class Sessions(Resource): @@ -17,34 +19,29 @@ def post(self): """ if not (request.data and "username" in request.json and "password" in request.json): return "Bad request", 400 - if request.json["username"] == "user" and request.json["password"] == "password": - session_id = str(uuid.uuid1()) - insert_row_into_table(SESSION, SESSION(ID=session_id)) - return {"sessionID": session_id}, 201 - return "Forbidden", 403 - - @requires_session_id - @queries_records + try: + return {"sessionID": backend.login(request.json)}, 201 + except AuthenticationError: + return "Forbidden", 403 + def delete(self): """ Deletes a users sessionID when they logout :return: Blank response, 200 """ - delete_row_by_id(SESSION, get_session_id_from_auth_header()) + backend.logout(get_session_id_from_auth_header()) return "", 200 - @requires_session_id def get(self): """ Gives details of a users session :return: String: Details of the session, 200 """ - return get_row_by_id(SESSION, get_session_id_from_auth_header()).to_dict(), 200 + return backend.get_session_details(get_session_id_from_auth_header()).to_dict(), 200 - @requires_session_id def put(self): """ Refreshes a users session :return: String: The session ID that has been refreshed, 200 """ - return get_session_id_from_auth_header(), 200 + return backend.refresh(get_session_id_from_auth_header()), 200 diff --git a/src/resources/table_endpoints/table_endpoints.py b/src/resources/table_endpoints/table_endpoints.py index bd289a4c..f90f7491 100644 --- a/src/resources/table_endpoints/table_endpoints.py +++ b/src/resources/table_endpoints/table_endpoints.py @@ -3,48 +3,37 @@ from common.database_helpers import get_investigations_for_user, get_investigations_for_user_count, \ get_facility_cycles_for_instrument, get_facility_cycles_for_instrument_count, \ get_investigations_for_instrument_in_facility_cycle, get_investigations_for_instrument_in_facility_cycle_count -from common.helpers import requires_session_id, queries_records, get_filters_from_query_string +from common.helpers import get_session_id_from_auth_header, get_filters_from_query_string +from common.backends import backend class UsersInvestigations(Resource): - @requires_session_id - @queries_records def get(self, id): - return get_investigations_for_user(id, get_filters_from_query_string()), 200 + return backend.get_investigations_for_user(get_session_id_from_auth_header(), id, get_filters_from_query_string()), 200 class UsersInvestigationsCount(Resource): - @requires_session_id - @queries_records def get(self, id): - return get_investigations_for_user_count(id, get_filters_from_query_string()), 200 + return backend.get_investigations_for_user_count(get_session_id_from_auth_header(), id, get_filters_from_query_string()), 200 class InstrumentsFacilityCycles(Resource): - @requires_session_id - @queries_records def get(self, id): - return get_facility_cycles_for_instrument(id, get_filters_from_query_string()), 200 + return backend.get_facility_cycles_for_instrument(get_session_id_from_auth_header(), id, get_filters_from_query_string()), 200 class InstrumentsFacilityCyclesCount(Resource): - @requires_session_id - @queries_records def get(self, id): - return get_facility_cycles_for_instrument_count(id, get_filters_from_query_string()), 200 + return backend.get_facility_cycles_for_instrument_count(get_session_id_from_auth_header(), id, get_filters_from_query_string()), 200 class InstrumentsFacilityCyclesInvestigations(Resource): - @requires_session_id - @queries_records def get(self, instrument_id, cycle_id): - return get_investigations_for_instrument_in_facility_cycle(instrument_id, cycle_id, - get_filters_from_query_string()), 200 + return backend.get_investigations_for_instrument_in_facility_cycle(get_session_id_from_auth_header(), instrument_id, cycle_id, + get_filters_from_query_string()), 200 class InstrumentsFacilityCyclesInvestigationsCount(Resource): - @requires_session_id - @queries_records def get(self, instrument_id, cycle_id): - return get_investigations_for_instrument_in_facility_cycle_count(instrument_id, cycle_id, - get_filters_from_query_string()), 200 + return backend.get_investigations_for_instrument_in_facility_cycle_count(get_session_id_from_auth_header(), instrument_id, cycle_id, + get_filters_from_query_string()), 200 diff --git a/test/test_base.py b/test/test_base.py index 61375e9e..9c8f1d9b 100644 --- a/test/test_base.py +++ b/test/test_base.py @@ -9,5 +9,5 @@ class FlaskAppTest(TestCase): """ def setUp(self): + app.config["TESTING"] = True self.app = app.test_client() - self.app.testing = True diff --git a/test/test_helpers.py b/test/test_helpers.py index d5a8b53b..ca80df90 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -42,7 +42,8 @@ class TestRequires_session_id(FlaskAppTest): def setUp(self): super().setUp() self.good_credentials_header = {"Authorization": "Bearer Test"} - self.bad_credentials_header = {"Authorization": "Test"} + self.invalid_credentials_header = {"Authorization": "Test"} + self.bad_credentials_header = {"Authorization": "Bearer BadTest"} session = SESSION() session.ID = "Test" insert_row_into_table(SESSION, session) @@ -53,11 +54,17 @@ def tearDown(self): def test_missing_credentials(self): self.assertEqual(401, self.app.get("/datafiles").status_code) + def test_invalid_credentials(self): + self.assertEqual(403, self.app.get( + "/datafiles", headers=self.invalid_credentials_header).status_code) + def test_bad_credentials(self): - self.assertEqual(403, self.app.get("/datafiles", headers=self.bad_credentials_header).status_code) + self.assertEqual(403, self.app.get( + "/datafiles", headers=self.bad_credentials_header).status_code) def test_good_credentials(self): - self.assertEqual(200, self.app.get("/datafiles?limit=0", headers=self.good_credentials_header).status_code) + self.assertEqual(200, self.app.get("/datafiles?limit=0", + headers=self.good_credentials_header).status_code) class TestQueries_records(TestCase): @@ -66,42 +73,65 @@ def test_missing_record_error(self): def raise_missing_record(): raise MissingRecordError() - self.assertEqual(("No such record in table", 404), raise_missing_record()) + with self.assertRaises(MissingRecordError) as ctx: + raise_missing_record() + self.assertEqual("No such record in table", str(ctx.exception)) + self.assertEqual(404, ctx.exception.status_code) def test_bad_filter_error(self): @queries_records def raise_bad_filter_error(): raise BadFilterError() - self.assertEqual(("Invalid filter requested", 400), raise_bad_filter_error()) + with self.assertRaises(BadFilterError) as ctx: + raise_bad_filter_error() + + self.assertEqual("Invalid filter requested", str(ctx.exception)) + self.assertEqual(400, ctx.exception.status_code) def test_value_error(self): @queries_records def raise_value_error(): raise ValueError() - self.assertEqual(("Bad request", 400), raise_value_error()) + with self.assertRaises(BadRequestError) as ctx: + raise_value_error() + + self.assertEqual("Bad request", str(ctx.exception)) + self.assertEqual(400, ctx.exception.status_code) def test_type_error(self): @queries_records def raise_type_error(): raise TypeError() - self.assertEqual(("Bad request", 400), raise_type_error()) + with self.assertRaises(BadRequestError) as ctx: + raise_type_error() + + self.assertEqual("Bad request", str(ctx.exception)) + self.assertEqual(400, ctx.exception.status_code) def test_integrity_error(self): @queries_records def raise_integrity_error(): raise IntegrityError() - self.assertEqual(("Bad request", 400), raise_integrity_error()) + with self.assertRaises(BadRequestError) as ctx: + raise_integrity_error() + + self.assertEqual("Bad request", str(ctx.exception)) + self.assertEqual(400, ctx.exception.status_code) def test_bad_request_error(self): @queries_records def raise_bad_request_error(): raise BadRequestError() - self.assertEqual(("Bad request", 400), raise_bad_request_error()) + with self.assertRaises(BadRequestError) as ctx: + raise_bad_request_error() + + self.assertEqual("Bad request", str(ctx.exception)) + self.assertEqual(400, ctx.exception.status_code) class TestGet_session_id_from_auth_header(FlaskAppTest): @@ -109,12 +139,14 @@ class TestGet_session_id_from_auth_header(FlaskAppTest): def test_no_session_in_header(self): with self.app: self.app.get("/") - self.assertRaises(MissingCredentialsError, get_session_id_from_auth_header) + self.assertRaises(MissingCredentialsError, + get_session_id_from_auth_header) def test_with_bad_header(self): with self.app: self.app.get("/", headers={"Authorization": "test"}) - self.assertRaises(AuthenticationError, get_session_id_from_auth_header) + self.assertRaises(AuthenticationError, + get_session_id_from_auth_header) def test_with_good_header(self): with self.app: @@ -137,47 +169,58 @@ def test_limit_filter(self): with self.app: self.app.get("/?limit=10") filters = get_filters_from_query_string() - self.assertEqual(1, len(filters), msg="Returned incorrect number of filters") - self.assertIs(LimitFilter, type(filters[0]), msg="Incorrect type of filter") + self.assertEqual( + 1, len(filters), msg="Returned incorrect number of filters") + self.assertIs(LimitFilter, type( + filters[0]), msg="Incorrect type of filter") def test_order_filter(self): with self.app: self.app.get("/?order=\"ID DESC\"") filters = get_filters_from_query_string() - self.assertEqual(1, len(filters), msg="Returned incorrect number of filters") - self.assertIs(OrderFilter, type(filters[0]), msg="Incorrect type of filter returned") + self.assertEqual( + 1, len(filters), msg="Returned incorrect number of filters") + self.assertIs(OrderFilter, type( + filters[0]), msg="Incorrect type of filter returned") def test_where_filter(self): with self.app: self.app.get('/?where={"ID":{"eq":3}}') filters = get_filters_from_query_string() - self.assertEqual(1, len(filters), msg="Returned incorrect number of filters") - self.assertIs(WhereFilter, type(filters[0]), msg="Incorrect type of filter returned") + self.assertEqual( + 1, len(filters), msg="Returned incorrect number of filters") + self.assertIs(WhereFilter, type( + filters[0]), msg="Incorrect type of filter returned") def test_skip_filter(self): with self.app: self.app.get('/?skip=10') filters = get_filters_from_query_string() - self.assertEqual(1, len(filters),msg="Returned incorrect number of filters") - self.assertIs(SkipFilter, type(filters[0]), msg="Incorrect type of filter returned") - + self.assertEqual( + 1, len(filters), msg="Returned incorrect number of filters") + self.assertIs(SkipFilter, type( + filters[0]), msg="Incorrect type of filter returned") def test_include_filter(self): with self.app: self.app.get("/?include=\"TEST\"") - filters =get_filters_from_query_string() - self.assertEqual(1, len(filters), msg="Incorrect number of filters returned") - self.assertIs(IncludeFilter, type(filters[0]), msg="Incorrect type of filter returned") + filters = get_filters_from_query_string() + self.assertEqual( + 1, len(filters), msg="Incorrect number of filters returned") + self.assertIs(IncludeFilter, type( + filters[0]), msg="Incorrect type of filter returned") def test_distinct_filter(self): with self.app: self.app.get("/?distinct=\"ID\"") filters = get_filters_from_query_string() - self.assertEqual(1, len(filters), msg="Incorrect number of filters returned") - self.assertIs(DistinctFieldFilter, type(filters[0]), msg="Incorrect type of filter returned") + self.assertEqual( + 1, len(filters), msg="Incorrect number of filters returned") + self.assertIs(DistinctFieldFilter, type( + filters[0]), msg="Incorrect type of filter returned") def test_multiple_filters(self): with self.app: self.app.get("/?limit=10&skip=4") filters = get_filters_from_query_string() - self.assertEqual(2, len(filters)) \ No newline at end of file + self.assertEqual(2, len(filters))