From f1e6e8ac1978b0703fac5ed072574fc092aef1eb Mon Sep 17 00:00:00 2001 From: Keiran Price Date: Thu, 17 Oct 2019 12:44:17 +0100 Subject: [PATCH] #42: Add way to filter included entities --- common/database_helpers.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/common/database_helpers.py b/common/database_helpers.py index 4c2c0f34..edf62ecd 100644 --- a/common/database_helpers.py +++ b/common/database_helpers.py @@ -6,6 +6,7 @@ from sqlalchemy.orm import aliased from common.exceptions import MissingRecordError, BadFilterError, BadRequestError, MultipleIncludeError +from common.models import db_models from common.models.db_models import INVESTIGATIONUSER, INVESTIGATION, INSTRUMENT, FACILITYCYCLE, \ INVESTIGATIONINSTRUMENT, FACILITY from common.session_manager import session_manager @@ -146,22 +147,31 @@ class WhereFilter(QueryFilter): precedence = 1 def __init__(self, field, value, operation): - self.field = field + self.field = field if "." not in field else field.split(".")[0] + self.included_field = None if "." not in field else field.split(".")[1] self.value = value self.operation = operation def apply_filter(self, query): + try: + field = getattr(query.table, self.field) + except AttributeError: + raise BadFilterError(f"Bad WhereFilter requested") + if self.included_field: + included_table = getattr(db_models, self.field) + query.base_query = query.base_query.join(included_table) + field = getattr(included_table, self.included_field) if self.operation == "eq": - query.base_query = query.base_query.filter(getattr(query.table, self.field) == self.value) + query.base_query = query.base_query.filter(field == self.value) elif self.operation == "like": - query.base_query = query.base_query.filter(getattr(query.table, self.field).like(f"%{self.value}%")) + query.base_query = query.base_query.filter(field.like(f"%{self.value}%")) elif self.operation == "lte": - query.base_query = query.base_query.filter(getattr(query.table, self.field) <= self.value) + query.base_query = query.base_query.filter(field <= self.value) elif self.operation == "gte": - query.base_query = query.base_query.filter(getattr(query.table, self.field) >= self.value) + query.base_query = query.base_query.filter(field >= self.value) elif self.operation == "in": - query.base_query = query.base_query.filter(getattr(query.table, self.field).in_(self.value)) + query.base_query = query.base_query.filter(field.in_(self.value)) else: raise BadFilterError(f" Bad operation given to where filter. operation: {self.operation}")