Skip to content

Commit

Permalink
#42: Add way to filter included entities
Browse files Browse the repository at this point in the history
  • Loading branch information
keiranjprice101 committed Oct 17, 2019
1 parent c74c0cf commit f1e6e8a
Showing 1 changed file with 16 additions and 6 deletions.
22 changes: 16 additions & 6 deletions common/database_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from sqlalchemy.orm import aliased

from common.exceptions import MissingRecordError, BadFilterError, BadRequestError, MultipleIncludeError
from common.models import db_models
from common.models.db_models import INVESTIGATIONUSER, INVESTIGATION, INSTRUMENT, FACILITYCYCLE, \
INVESTIGATIONINSTRUMENT, FACILITY
from common.session_manager import session_manager
Expand Down Expand Up @@ -146,22 +147,31 @@ class WhereFilter(QueryFilter):
precedence = 1

def __init__(self, field, value, operation):
self.field = field
self.field = field if "." not in field else field.split(".")[0]
self.included_field = None if "." not in field else field.split(".")[1]
self.value = value
self.operation = operation

def apply_filter(self, query):
try:
field = getattr(query.table, self.field)
except AttributeError:
raise BadFilterError(f"Bad WhereFilter requested")
if self.included_field:
included_table = getattr(db_models, self.field)
query.base_query = query.base_query.join(included_table)
field = getattr(included_table, self.included_field)

if self.operation == "eq":
query.base_query = query.base_query.filter(getattr(query.table, self.field) == self.value)
query.base_query = query.base_query.filter(field == self.value)
elif self.operation == "like":
query.base_query = query.base_query.filter(getattr(query.table, self.field).like(f"%{self.value}%"))
query.base_query = query.base_query.filter(field.like(f"%{self.value}%"))
elif self.operation == "lte":
query.base_query = query.base_query.filter(getattr(query.table, self.field) <= self.value)
query.base_query = query.base_query.filter(field <= self.value)
elif self.operation == "gte":
query.base_query = query.base_query.filter(getattr(query.table, self.field) >= self.value)
query.base_query = query.base_query.filter(field >= self.value)
elif self.operation == "in":
query.base_query = query.base_query.filter(getattr(query.table, self.field).in_(self.value))
query.base_query = query.base_query.filter(field.in_(self.value))
else:
raise BadFilterError(f" Bad operation given to where filter. operation: {self.operation}")

Expand Down

0 comments on commit f1e6e8a

Please sign in to comment.