Skip to content

Commit

Permalink
Merge pull request #37 from ral-facilities/18_add_more_where_filterin…
Browse files Browse the repository at this point in the history
…g_types

Add Like, less than and greater than where filtering.
  • Loading branch information
keiranjprice101 authored Aug 15, 2019
2 parents fa374f7 + b2d2750 commit 7b6adef
Showing 1 changed file with 22 additions and 5 deletions.
27 changes: 22 additions & 5 deletions common/database_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def __init__(self, table, row, new_values):
def execute_query(self):
log.info(f" Updating row in {self.table}")
self.row.update_from_dict(self.new_values)
self.session.add(self.row)
self.commit_changes()


Expand All @@ -124,12 +125,22 @@ def apply_filter(self, query):
class WhereFilter(QueryFilter):
precedence = 0

def __init__(self, field, value):
def __init__(self, field, value, operation):
self.field = field
self.value = value
self.operation = operation

def apply_filter(self, query):
query.base_query = query.base_query.filter(getattr(query.table, self.field) == self.value)
if self.operation == "eq":
query.base_query = query.base_query.filter(getattr(query.table, self.field) == self.value)
elif self.operation == "like":
query.base_query = query.base_query.filter(getattr(query.table, self.field).like(f"%{self.value}%"))
elif self.operation == "lte":
query.base_query = query.base_query.filter(getattr(query.table, self.field) <= self.value)
elif self.operation == "gte":
query.base_query = query.base_query.filter(getattr(query.table, self.field) >= self.value)
else:
raise BadFilterError(f" Bad operation given to where filter. operation: {self.operation}")


class OrderFilter(QueryFilter):
Expand Down Expand Up @@ -188,9 +199,14 @@ def get_query_filter(filter):
"""
filter_name = list(filter)[0].lower()
if filter_name == "where":
return WhereFilter(list(filter["where"])[0], filter["where"][list(filter["where"])[0]])
field = list(filter[filter_name].keys())[0]
operation = list(filter[filter_name][field].keys())[0]
value = filter[filter_name][field][operation]
return WhereFilter(field, value, operation)
elif filter_name == "order":
return OrderFilter(filter["order"].split(" ")[0], filter["order"].split(" ")[1])
field = filter["order"].split(" ")[0]
direction = filter["order"].split(" ")[1]
return OrderFilter(field, direction)
elif filter_name == "skip":
return SkipFilter(filter["skip"])
elif filter_name == "limit":
Expand Down Expand Up @@ -230,6 +246,7 @@ def apply_filters(self, query):
def insert_row_into_table(table, row):
"""
Insert the given row into its table
:param table: The table to be inserted to
:param row: The row to be inserted
"""
create_query = CreateQuery(table, row)
Expand Down Expand Up @@ -257,7 +274,7 @@ def get_row_by_id(table, id):
read_query = ReadQuery(table)
try:
log.info(f" Querying {table.__tablename__} for record with ID: {id}")
where_filter = WhereFilter("ID", id)
where_filter = WhereFilter("ID", id, "eq")
where_filter.apply_filter(read_query)
return read_query.get_single_result()
finally:
Expand Down

0 comments on commit 7b6adef

Please sign in to comment.