Skip to content

Commit

Permalink
Merge pull request #56 from ral-facilities/41_add_ability_to_get_dist…
Browse files Browse the repository at this point in the history
…inct_values

Add ability to get distinct values
  • Loading branch information
keiranjprice101 authored Sep 12, 2019
2 parents bb0a2a3 + 12adcab commit 5695e3f
Showing 1 changed file with 60 additions and 10 deletions.
70 changes: 60 additions & 10 deletions common/database_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class ReadQuery(Query):
def __init__(self, table):
super().__init__(table)
self.include_related_entities = False
self.is_distinct_fields_query = False

def commit_changes(self):
log.info("Closing DB session")
Expand Down Expand Up @@ -124,13 +125,18 @@ def execute_query(self):


class QueryFilter(ABC):
@property
@abstractmethod
def precedence(self):
pass

@abstractmethod
def apply_filter(self, query):
pass


class WhereFilter(QueryFilter):
precedence = 0
precedence = 1

def __init__(self, field, value, operation):
self.field = field
Expand All @@ -151,8 +157,23 @@ def apply_filter(self, query):
raise BadFilterError(f" Bad operation given to where filter. operation: {self.operation}")


class DistinctFieldFilter(QueryFilter):
precedence = 0

def __init__(self, fields):
self.fields = fields if type(fields) is list else [fields] # This allows single string distinct filters

def apply_filter(self, query):
query.is_distinct_fields_query = True
try:
self.fields = [getattr(query.table, field) for field in self.fields]
except AttributeError:
raise BadFilterError("Bad field requested")
query.base_query = query.session.query(*self.fields).distinct()


class OrderFilter(QueryFilter):
precedence = 1
precedence = 2

def __init__(self, field, direction):
self.field = field
Expand All @@ -168,7 +189,7 @@ def apply_filter(self, query):


class SkipFilter(QueryFilter):
precedence = 2
precedence = 3

def __init__(self, skip_value):
self.skip_value = skip_value
Expand All @@ -178,7 +199,7 @@ def apply_filter(self, query):


class LimitFilter(QueryFilter):
precedence = 3
precedence = 4

def __init__(self, limit_value):
self.limit_value = limit_value
Expand All @@ -188,7 +209,7 @@ def apply_filter(self, query):


class IncludeFilter(QueryFilter):
precedence = 4
precedence = 5

def __init__(self, included_filters):
self.included_filters = included_filters
Expand Down Expand Up @@ -221,6 +242,8 @@ def get_query_filter(filter):
return LimitFilter(filter["limit"])
elif filter_name == "include":
return IncludeFilter(filter)
elif filter_name == "distinct":
return DistinctFieldFilter(filter["distinct"])
else:
raise BadFilterError(f" Bad filter: {filter}")

Expand Down Expand Up @@ -330,16 +353,43 @@ def get_filtered_read_query_results(filter_handler, filters, query):
filter_handler.add_filter(QueryFilterFactory.get_query_filter(query_filter))
filter_handler.apply_filters(query)
results = query.get_all_results()
if query.is_distinct_fields_query:
return _get_distinct_fields_as_dicts(results)
if query.include_related_entities:
for query_filter in filters:
if list(query_filter)[0].lower() == "include":
return list(map(lambda x: x.to_nested_dict(query_filter["include"]), results))
return _get_results_with_include(filters, results)
return list(map(lambda x: x.to_dict(), results))

finally:
query.session.close()


def _get_results_with_include(filters, results):
"""
Given a list of entities and a list of filters, use the include filter to nest the included entities requested in
the include filter given
:param filters: The list of filters
:param results: The list of entities
:return: A list of nested dictionaries representing the entity results
"""
for query_filter in filters:
if list(query_filter)[0].lower() == "include":
return [x.to_nested_dict(query_filter["include"]) for x in results]


def _get_distinct_fields_as_dicts(results):
"""
Given a list of column results return a list of dictionaries where each column name is the key and the column value
is the dictionary key value
:param results: A list of sql alchemy result objects
:return: A list of dictionary representations of the sqlalchemy result objects
"""
dictionaries = []
for result in results:
dictionary = {k: getattr(result, k) for k in result.keys()}
dictionaries.append(dictionary)
return dictionaries


def get_rows_by_filter(table, filters):
"""
Given a list of filters supplied in json format, returns entities that match the filters from the given table
Expand Down Expand Up @@ -465,7 +515,7 @@ class InstrumentFacilityCyclesQuery(ReadQuery):
def __init__(self, instrument_id):
super().__init__(FACILITYCYCLE)
investigationInstrument = aliased(INSTRUMENT)
self.base_query = self.base_query\
self.base_query = self.base_query \
.join(FACILITYCYCLE.FACILITY) \
.join(FACILITY.INSTRUMENT) \
.join(FACILITY.INVESTIGATION) \
Expand Down Expand Up @@ -504,7 +554,7 @@ class InstrumentFacilityCycleInvestigationsQuery(ReadQuery):
def __init__(self, instrument_id, facility_cycle_id):
super().__init__(INVESTIGATION)
investigationInstrument = aliased(INSTRUMENT)
self.base_query = self.base_query\
self.base_query = self.base_query \
.join(INVESTIGATION.FACILITY) \
.join(FACILITY.FACILITYCYCLE) \
.join(FACILITY.INSTRUMENT) \
Expand Down

0 comments on commit 5695e3f

Please sign in to comment.