From 11f75a20a2dba6aac446fbaeb8d61a3f8a436233 Mon Sep 17 00:00:00 2001 From: Matthew Richards Date: Thu, 6 May 2021 17:49:22 +0000 Subject: [PATCH] #223: Implement DatabaseFilterUtilities into DatabaseWhereFilter - Similar implemenation as DatabaseDistinctFieldFilter --- datagateway_api/common/database/filters.py | 55 ++-------------------- 1 file changed, 5 insertions(+), 50 deletions(-) diff --git a/datagateway_api/common/database/filters.py b/datagateway_api/common/database/filters.py index b2545cd7..6bb41031 100644 --- a/datagateway_api/common/database/filters.py +++ b/datagateway_api/common/database/filters.py @@ -61,9 +61,6 @@ def _extract_filter_fields(self, field): else: raise ValueError(f"Maximum related depth exceeded. {field}'s depth > 3") - # TODO - Remove if not needed - # return (field, related_field, related_related_field) - def _add_query_join(self, query): """ Fetches the appropriate entity model based on the contents of `self.field` and @@ -97,58 +94,16 @@ def _add_query_join(self, query): return field -class DatabaseWhereFilter(WhereFilter): +class DatabaseWhereFilter(WhereFilter, DatabaseFilterUtilities): def __init__(self, field, value, operation): - super().__init__(field, value, operation) + # TODO - Apply any 'pythonic' solution here too + WhereFilter.__init__(self, field, value, operation) + DatabaseFilterUtilities.__init__(self) - self.included_field = None - self.included_included_field = None self._extract_filter_fields(field) - def _extract_filter_fields(self, field): - """ - Extract the related fields names and put them into separate variables - - :param field: ICAT field names, separated by dots - :type field: :class:`str` - """ - - fields = field.split(".") - include_depth = len(fields) - - log.debug("Fields: %s, Include Depth: %d", fields, include_depth) - - if include_depth == 1: - self.field = fields[0] - elif include_depth == 2: - self.field = fields[0] - self.included_field = fields[1] - elif include_depth == 3: - self.field = fields[0] - self.included_field = fields[1] - self.included_included_field = fields[2] - else: - raise ValueError(f"Maximum include depth exceeded. {field}'s depth > 3") - def apply_filter(self, query): - try: - field = getattr(query.table, self.field) - except AttributeError: - raise FilterError( - f"Unknown attribute {self.field} on table {query.table.__name__}", - ) - - if self.included_included_field: - included_table = getattr(models, self.field) - included_included_table = getattr(models, self.included_field) - query.base_query = query.base_query.join(included_table).join( - included_included_table, - ) - field = getattr(included_included_table, self.included_included_field) - elif self.included_field: - included_table = get_entity_object_from_name(self.field) - query.base_query = query.base_query.join(included_table) - field = getattr(included_table, self.included_field) + field = self._add_query_join(query) if self.operation == "eq": query.base_query = query.base_query.filter(field == self.value)