Skip to content

Commit

Permalink
Merge pull request #35 from ral-facilities/30_use_order_of_operations
Browse files Browse the repository at this point in the history
Use order of operations for filtering
  • Loading branch information
keiranjprice101 authored Aug 13, 2019
2 parents c6f0b87 + 38b9ec0 commit 2f59010
Showing 1 changed file with 44 additions and 8 deletions.
52 changes: 44 additions & 8 deletions common/database_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ def __init__(self, table):
self.session = session_manager.get_icat_db_session()
self.table = table
self.base_query = self.session.query(table)
self.is_limited = False

@abstractmethod
def execute_query(self):
Expand Down Expand Up @@ -123,6 +122,8 @@ def apply_filter(self, query):


class WhereFilter(QueryFilter):
precedence = 0

def __init__(self, field, value):
self.field = field
self.value = value
Expand All @@ -132,13 +133,13 @@ def apply_filter(self, query):


class OrderFilter(QueryFilter):
precedence = 1

def __init__(self, field, direction):
self.field = field
self.direction = direction

def apply_filter(self, query):
if query.is_limited:
query.base_query = query.base_query.from_self()
if self.direction.upper() == "ASC":
query.base_query = query.base_query.order_by(asc(self.field.upper()))
elif self.direction.upper() == "DESC":
Expand All @@ -148,6 +149,8 @@ def apply_filter(self, query):


class SkipFilter(QueryFilter):
precedence = 2

def __init__(self, skip_value):
self.skip_value = skip_value

Expand All @@ -156,15 +159,18 @@ def apply_filter(self, query):


class LimitFilter(QueryFilter):
precedence = 3

def __init__(self, limit_value):
self.limit_value = limit_value

def apply_filter(self, query):
query.base_query = query.base_query.limit(self.limit_value)
query.is_limited = True


class IncludeFilter(QueryFilter):
precedence = 4

def __init__(self, included_filters):
self.included_filters = included_filters

Expand Down Expand Up @@ -195,6 +201,32 @@ def get_query_filter(filter):
raise BadFilterError(f" Bad filter: {filter}")


class FilterOrderHandler(object):
"""
The FilterOrderHandler takes in filters, sorts them according to the order of operations, then applies them.
"""
def __init__(self):
self.filters = []

def add_filter(self, filter):
self.filters.append(filter)

def sort_filters(self):
"""
Sorts the filters according to the order of operations
"""
self.filters.sort(key=lambda x: x.precedence)

def apply_filters(self, query):
"""
Given a query apply the filters the handler has in the correct order.
:param query: The query to have filters applied to
"""
self.sort_filters()
for filter in self.filters:
filter.apply_filter(query)


def insert_row_into_table(table, row):
"""
Insert the given row into its table
Expand Down Expand Up @@ -264,12 +296,14 @@ def get_rows_by_filter(table, filters):
:return: A list of the rows returned in dictionary form
"""
query = ReadQuery(table)
filter_handler = FilterOrderHandler()
try:
for query_filter in filters:
if len(query_filter) == 0:
pass
else:
QueryFilterFactory.get_query_filter(query_filter).apply_filter(query)
filter_handler.add_filter(QueryFilterFactory.get_query_filter(query_filter))
filter_handler.apply_filters(query)
results = query.get_all_results()
if query.include_related_entities:
for query_filter in filters:
Expand Down Expand Up @@ -301,11 +335,13 @@ def get_filtered_row_count(table, filters):

log.info(f" getting count for {table.__tablename__}")
count_query = CountQuery(table)
for filter in filters:
if len(filter) == 0:
filter_handler = FilterOrderHandler()
for query_filter in filters:
if len(query_filter) == 0:
pass
else:
QueryFilterFactory.get_query_filter(filter).apply_filter(count_query)
filter_handler.add_filter(QueryFilterFactory.get_query_filter(query_filter))
filter_handler.apply_filters(count_query)
return count_query.get_count()


Expand Down

0 comments on commit 2f59010

Please sign in to comment.