Skip to content

Commit

Permalink
Merge pull request #96 from ral-facilities/42_add_filtering_to_includes
Browse files Browse the repository at this point in the history
Add way to filter included entities
  • Loading branch information
keiranjprice101 authored Oct 18, 2019
2 parents c74c0cf + cb7fbf2 commit 8191110
Showing 1 changed file with 35 additions and 5 deletions.
40 changes: 35 additions & 5 deletions common/database_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from sqlalchemy.orm import aliased

from common.exceptions import MissingRecordError, BadFilterError, BadRequestError, MultipleIncludeError
from common.models import db_models
from common.models.db_models import INVESTIGATIONUSER, INVESTIGATION, INSTRUMENT, FACILITYCYCLE, \
INVESTIGATIONINSTRUMENT, FACILITY
from common.session_manager import session_manager
Expand Down Expand Up @@ -147,21 +148,50 @@ class WhereFilter(QueryFilter):

def __init__(self, field, value, operation):
self.field = field
self.included_field = None
self.included_included_field = None
self._set_filter_fields()
self.value = value
self.operation = operation

def _set_filter_fields(self):
if self.field.count(".") == 1:
self.included_field = self.field.split(".")[1]
self.field = self.field.split(".")[0]

if self.field.count(".") == 2:
self.included_included_field = self.field.split(".")[2]
self.included_field = self.field.split(".")[1]
self.field = self.field.split(".")[0]


def apply_filter(self, query):
try:
field = getattr(query.table, self.field)
except AttributeError:
raise BadFilterError(f"Bad WhereFilter requested")

if self.included_included_field:
included_table = getattr(db_models, self.field)
included_included_table = getattr(db_models, self.included_field)
query.base_query = query.base_query.join(included_table).join(included_included_table)
field = getattr(included_included_table, self.included_included_field)

elif self.included_field:
included_table = getattr(db_models, self.field)
query.base_query = query.base_query.join(included_table)
field = getattr(included_table, self.included_field)

if self.operation == "eq":
query.base_query = query.base_query.filter(getattr(query.table, self.field) == self.value)
query.base_query = query.base_query.filter(field == self.value)
elif self.operation == "like":
query.base_query = query.base_query.filter(getattr(query.table, self.field).like(f"%{self.value}%"))
query.base_query = query.base_query.filter(field.like(f"%{self.value}%"))
elif self.operation == "lte":
query.base_query = query.base_query.filter(getattr(query.table, self.field) <= self.value)
query.base_query = query.base_query.filter(field <= self.value)
elif self.operation == "gte":
query.base_query = query.base_query.filter(getattr(query.table, self.field) >= self.value)
query.base_query = query.base_query.filter(field >= self.value)
elif self.operation == "in":
query.base_query = query.base_query.filter(getattr(query.table, self.field).in_(self.value))
query.base_query = query.base_query.filter(field.in_(self.value))
else:
raise BadFilterError(f" Bad operation given to where filter. operation: {self.operation}")

Expand Down

0 comments on commit 8191110

Please sign in to comment.