Skip to content

Commit

Permalink
#223: Implement DatabaseFilterUtilities into DatabaseWhereFilter
Browse files Browse the repository at this point in the history
- Similar implemenation as DatabaseDistinctFieldFilter
  • Loading branch information
MRichards99 committed May 6, 2021
1 parent 1b7f132 commit 11f75a2
Showing 1 changed file with 5 additions and 50 deletions.
55 changes: 5 additions & 50 deletions datagateway_api/common/database/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,6 @@ def _extract_filter_fields(self, field):
else:
raise ValueError(f"Maximum related depth exceeded. {field}'s depth > 3")

# TODO - Remove if not needed
# return (field, related_field, related_related_field)

def _add_query_join(self, query):
"""
Fetches the appropriate entity model based on the contents of `self.field` and
Expand Down Expand Up @@ -97,58 +94,16 @@ def _add_query_join(self, query):
return field


class DatabaseWhereFilter(WhereFilter):
class DatabaseWhereFilter(WhereFilter, DatabaseFilterUtilities):
def __init__(self, field, value, operation):
super().__init__(field, value, operation)
# TODO - Apply any 'pythonic' solution here too
WhereFilter.__init__(self, field, value, operation)
DatabaseFilterUtilities.__init__(self)

self.included_field = None
self.included_included_field = None
self._extract_filter_fields(field)

def _extract_filter_fields(self, field):
"""
Extract the related fields names and put them into separate variables
:param field: ICAT field names, separated by dots
:type field: :class:`str`
"""

fields = field.split(".")
include_depth = len(fields)

log.debug("Fields: %s, Include Depth: %d", fields, include_depth)

if include_depth == 1:
self.field = fields[0]
elif include_depth == 2:
self.field = fields[0]
self.included_field = fields[1]
elif include_depth == 3:
self.field = fields[0]
self.included_field = fields[1]
self.included_included_field = fields[2]
else:
raise ValueError(f"Maximum include depth exceeded. {field}'s depth > 3")

def apply_filter(self, query):
try:
field = getattr(query.table, self.field)
except AttributeError:
raise FilterError(
f"Unknown attribute {self.field} on table {query.table.__name__}",
)

if self.included_included_field:
included_table = getattr(models, self.field)
included_included_table = getattr(models, self.included_field)
query.base_query = query.base_query.join(included_table).join(
included_included_table,
)
field = getattr(included_included_table, self.included_included_field)
elif self.included_field:
included_table = get_entity_object_from_name(self.field)
query.base_query = query.base_query.join(included_table)
field = getattr(included_table, self.included_field)
field = self._add_query_join(query)

if self.operation == "eq":
query.base_query = query.base_query.filter(field == self.value)
Expand Down

0 comments on commit 11f75a2

Please sign in to comment.