diff --git a/common/database_helpers.py b/common/database_helpers.py index be1e745e..d9645fb9 100644 --- a/common/database_helpers.py +++ b/common/database_helpers.py @@ -55,6 +55,7 @@ class ReadQuery(Query): def __init__(self, table): super().__init__(table) self.include_related_entities = False + self.is_distinct_fields_query = False def commit_changes(self): log.info("Closing DB session") @@ -124,13 +125,18 @@ def execute_query(self): class QueryFilter(ABC): + @property + @abstractmethod + def precedence(self): + pass + @abstractmethod def apply_filter(self, query): pass class WhereFilter(QueryFilter): - precedence = 0 + precedence = 1 def __init__(self, field, value, operation): self.field = field @@ -151,8 +157,23 @@ def apply_filter(self, query): raise BadFilterError(f" Bad operation given to where filter. operation: {self.operation}") +class DistinctFieldFilter(QueryFilter): + precedence = 0 + + def __init__(self, fields): + self.fields = fields if type(fields) is list else [fields] # This allows single string distinct filters + + def apply_filter(self, query): + query.is_distinct_fields_query = True + try: + self.fields = [getattr(query.table, field) for field in self.fields] + except AttributeError: + raise BadFilterError("Bad field requested") + query.base_query = query.session.query(*self.fields).distinct() + + class OrderFilter(QueryFilter): - precedence = 1 + precedence = 2 def __init__(self, field, direction): self.field = field @@ -168,7 +189,7 @@ def apply_filter(self, query): class SkipFilter(QueryFilter): - precedence = 2 + precedence = 3 def __init__(self, skip_value): self.skip_value = skip_value @@ -178,7 +199,7 @@ def apply_filter(self, query): class LimitFilter(QueryFilter): - precedence = 3 + precedence = 4 def __init__(self, limit_value): self.limit_value = limit_value @@ -188,7 +209,7 @@ def apply_filter(self, query): class IncludeFilter(QueryFilter): - precedence = 4 + precedence = 5 def __init__(self, included_filters): self.included_filters = included_filters @@ -221,6 +242,8 @@ def get_query_filter(filter): return LimitFilter(filter["limit"]) elif filter_name == "include": return IncludeFilter(filter) + elif filter_name == "distinct": + return DistinctFieldFilter(filter["distinct"]) else: raise BadFilterError(f" Bad filter: {filter}") @@ -330,16 +353,43 @@ def get_filtered_read_query_results(filter_handler, filters, query): filter_handler.add_filter(QueryFilterFactory.get_query_filter(query_filter)) filter_handler.apply_filters(query) results = query.get_all_results() + if query.is_distinct_fields_query: + return _get_distinct_fields_as_dicts(results) if query.include_related_entities: - for query_filter in filters: - if list(query_filter)[0].lower() == "include": - return list(map(lambda x: x.to_nested_dict(query_filter["include"]), results)) + return _get_results_with_include(filters, results) return list(map(lambda x: x.to_dict(), results)) finally: query.session.close() +def _get_results_with_include(filters, results): + """ + Given a list of entities and a list of filters, use the include filter to nest the included entities requested in + the include filter given + :param filters: The list of filters + :param results: The list of entities + :return: A list of nested dictionaries representing the entity results + """ + for query_filter in filters: + if list(query_filter)[0].lower() == "include": + return [x.to_nested_dict(query_filter["include"]) for x in results] + + +def _get_distinct_fields_as_dicts(results): + """ + Given a list of column results return a list of dictionaries where each column name is the key and the column value + is the dictionary key value + :param results: A list of sql alchemy result objects + :return: A list of dictionary representations of the sqlalchemy result objects + """ + dictionaries = [] + for result in results: + dictionary = {k: getattr(result, k) for k in result.keys()} + dictionaries.append(dictionary) + return dictionaries + + def get_rows_by_filter(table, filters): """ Given a list of filters supplied in json format, returns entities that match the filters from the given table @@ -465,7 +515,7 @@ class InstrumentFacilityCyclesQuery(ReadQuery): def __init__(self, instrument_id): super().__init__(FACILITYCYCLE) investigationInstrument = aliased(INSTRUMENT) - self.base_query = self.base_query\ + self.base_query = self.base_query \ .join(FACILITYCYCLE.FACILITY) \ .join(FACILITY.INSTRUMENT) \ .join(FACILITY.INVESTIGATION) \ @@ -504,7 +554,7 @@ class InstrumentFacilityCycleInvestigationsQuery(ReadQuery): def __init__(self, instrument_id, facility_cycle_id): super().__init__(INVESTIGATION) investigationInstrument = aliased(INSTRUMENT) - self.base_query = self.base_query\ + self.base_query = self.base_query \ .join(INVESTIGATION.FACILITY) \ .join(FACILITY.FACILITYCYCLE) \ .join(FACILITY.INSTRUMENT) \