Skip to content

Commit

Permalink
Merge pull request #233 from ral-facilities/feature/distinct-filter-r…
Browse files Browse the repository at this point in the history
…elated-entities-#223

Allow related entities on DB distinct filter
  • Loading branch information
MRichards99 authored May 19, 2021
2 parents 5901e03 + da118af commit 205f155
Show file tree
Hide file tree
Showing 8 changed files with 484 additions and 173 deletions.
140 changes: 106 additions & 34 deletions datagateway_api/common/database/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from sqlalchemy import asc, desc

from datagateway_api.common.database import models
from datagateway_api.common.exceptions import FilterError, MultipleIncludeError
from datagateway_api.common.filters import (
DistinctFieldFilter,
Expand All @@ -18,58 +17,114 @@
log = logging.getLogger()


class DatabaseWhereFilter(WhereFilter):
def __init__(self, field, value, operation):
super().__init__(field, value, operation)
class DatabaseFilterUtilities:
"""
Class containing utility functions used in the WhereFilter and DistinctFilter
In this class, the terminology of 'included entities' has been made more generic to
'related entities'. When these functions are used with the WhereFilter, the related
entities are in fact included entities (entities which are also present in the input
of an include filter in the same request). However, when these functions are used
with the DistinctFilter, they are related entities, not included entities as there's
no requirement for the entity names to also be present in an include filter (this is
to match the ICAT backend)
"""

self.included_field = None
self.included_included_field = None
self._extract_filter_fields(field)
def __init__(self):
self.field = None
self.related_field = None
self.related_related_field = None

def _extract_filter_fields(self, field):
def extract_filter_fields(self, field):
"""
Extract the related fields names and put them into separate variables
:param field: ICAT field names, separated by dots
:type field: :class:`str`
:raises ValueError: If the maximum related/included depth is exceeded
"""

# Flushing fields in case they have been previously set
self.field = None
self.related_field = None
self.related_related_field = None

fields = field.split(".")
include_depth = len(fields)
related_depth = len(fields)

log.debug("Fields: %s, Include Depth: %d", fields, include_depth)
log.debug("Fields: %s, Related Depth: %d", fields, related_depth)

if include_depth == 1:
if related_depth == 1:
self.field = fields[0]
elif include_depth == 2:
self.distinct_join_flag = True
elif related_depth == 2:
self.field = fields[0]
self.included_field = fields[1]
elif include_depth == 3:
self.related_field = fields[1]
elif related_depth == 3:
self.field = fields[0]
self.included_field = fields[1]
self.included_included_field = fields[2]
self.related_field = fields[1]
self.related_related_field = fields[2]
else:
raise ValueError(f"Maximum include depth exceeded. {field}'s depth > 3")
raise ValueError(f"Maximum related depth exceeded. {field}'s depth > 3")

def apply_filter(self, query):
try:
field = getattr(query.table, self.field)
except AttributeError:
raise FilterError(
f"Unknown attribute {self.field} on table {query.table.__name__}",
)
def add_query_join(self, query):
"""
Adds any required JOINs to the query if any related fields have been used in the
filter
:param query: The query to have filters applied to
:type query: :class:`datagateway_api.common.database.helpers.[QUERY]`
"""

if self.included_included_field:
included_table = getattr(models, self.field)
included_included_table = getattr(models, self.included_field)
if self.related_related_field:
included_table = get_entity_object_from_name(self.field)
included_included_table = get_entity_object_from_name(self.related_field)
query.base_query = query.base_query.join(included_table).join(
included_included_table,
)
field = getattr(included_included_table, self.included_included_field)
elif self.included_field:
elif self.related_field:
included_table = get_entity_object_from_name(self.field)
query.base_query = query.base_query.join(included_table)
field = getattr(included_table, self.included_field)

def get_entity_model_for_filter(self, query):
"""
Fetches the appropriate entity model based on the contents of the instance
variables of this class
:param query: The query to have filters applied to
:type query: :class:`datagateway_api.common.database.helpers.[QUERY]`
:return: Entity model of the field (usually the field relating to the endpoint
the request is coming from)
"""
if self.related_related_field:
included_included_table = get_entity_object_from_name(self.related_field)
field = self._get_field(included_included_table, self.related_related_field)
elif self.related_field:
included_table = get_entity_object_from_name(self.field)
field = self._get_field(included_table, self.related_field)
else:
# No related fields
field = self._get_field(query.table, self.field)

return field

def _get_field(self, table, field):
try:
return getattr(table, field)
except AttributeError:
raise FilterError(f"Unknown attribute {field} on table {table.__name__}")


class DatabaseWhereFilter(WhereFilter, DatabaseFilterUtilities):
def __init__(self, field, value, operation):
WhereFilter.__init__(self, field, value, operation)
DatabaseFilterUtilities.__init__(self)

self.extract_filter_fields(field)

def apply_filter(self, query):
self.add_query_join(query)
field = self.get_entity_model_for_filter(query)

if self.operation == "eq":
query.base_query = query.base_query.filter(field == self.value)
Expand All @@ -95,17 +150,34 @@ def apply_filter(self, query):
)


class DatabaseDistinctFieldFilter(DistinctFieldFilter):
class DatabaseDistinctFieldFilter(DistinctFieldFilter, DatabaseFilterUtilities):
def __init__(self, fields):
super().__init__(fields)
DistinctFieldFilter.__init__(self, fields)
DatabaseFilterUtilities.__init__(self)

def apply_filter(self, query):
query.is_distinct_fields_query = True

try:
self.fields = [getattr(query.table, field) for field in self.fields]
distinct_fields = []
for field_name in self.fields:
self.extract_filter_fields(field_name)
distinct_fields.append(self.get_entity_model_for_filter(query))

# Base query must be set to a DISTINCT query before adding JOINs - if these
# actions are done in the opposite order, the JOINs will overwrite the
# SELECT multiple and effectively turn the query into a `SELECT *`
query.base_query = (
query.session.query(*distinct_fields)
.select_from(query.table)
.distinct()
)

for field_name in self.fields:
self.extract_filter_fields(field_name)
self.add_query_join(query)
except AttributeError:
raise FilterError("Bad field requested")
query.base_query = query.session.query(*self.fields).distinct()


class DatabaseOrderFilter(OrderFilter):
Expand Down
14 changes: 11 additions & 3 deletions datagateway_api/common/database/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from sqlalchemy.orm import aliased

from datagateway_api.common.database.filters import (
DatabaseDistinctFieldFilter,
DatabaseIncludeFilter as IncludeFilter,
DatabaseWhereFilter as WhereFilter,
)
Expand All @@ -24,6 +25,7 @@
MissingRecordError,
)
from datagateway_api.common.filter_order_handler import FilterOrderHandler
from datagateway_api.common.helpers import map_distinct_attributes_to_results


log = logging.getLogger()
Expand Down Expand Up @@ -278,7 +280,7 @@ def get_filtered_read_query_results(filter_handler, filters, query):
filter_handler.apply_filters(query)
results = query.get_all_results()
if query.is_distinct_fields_query:
return _get_distinct_fields_as_dicts(results)
return _get_distinct_fields_as_dicts(filters, results)
if query.include_related_entities:
return _get_results_with_include(filters, results)
return list(map(lambda x: x.to_dict(), results))
Expand All @@ -298,18 +300,24 @@ def _get_results_with_include(filters, results):
return [x.to_nested_dict(query_filter.included_filters) for x in results]


def _get_distinct_fields_as_dicts(results):
def _get_distinct_fields_as_dicts(filters, results):
"""
Given a list of column results return a list of dictionaries where each column name
is the key and the column value is the dictionary key value
:param results: A list of sql alchemy result objects
:return: A list of dictionary representations of the sqlalchemy result objects
"""
distinct_fields = []
for query_filter in filters:
if type(query_filter) is DatabaseDistinctFieldFilter:
distinct_fields.extend(query_filter.fields)

dictionaries = []
for result in results:
dictionary = {k: getattr(result, k) for k in result.keys()}
dictionary = map_distinct_attributes_to_results(distinct_fields, result)
dictionaries.append(dictionary)

return dictionaries


Expand Down
76 changes: 76 additions & 0 deletions datagateway_api/common/helpers.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from datetime import datetime
from functools import wraps
import json
import logging

from dateutil.tz.tz import tzlocal
from flask import request
from flask_restful import reqparse
from sqlalchemy.exc import IntegrityError

from datagateway_api.common.database import models
from datagateway_api.common.date_handler import DateHandler
from datagateway_api.common.exceptions import (
ApiError,
AuthenticationError,
Expand Down Expand Up @@ -126,3 +129,76 @@ def get_entity_object_from_name(entity_name):
raise ApiError(
f"Entity class cannot be found, missing class for {entity_name}",
)


def map_distinct_attributes_to_results(distinct_attributes, query_result):
"""
Maps the attribute names from a distinct filter onto the results given by the result
of a query
When selecting multiple (but not all) attributes in a database query, the results
are returned in a list and not mapped to an entity object. This means the 'normal'
functions used to process data ready for output (`entity_to_dict()` for the ICAT
backend) cannot be used, as the structure of the query result is different.
:param distinct_attributes: List of distinct attributes from the distinct
filter of the incoming request
:type distinct_attributes: :class:`list`
:param query_result: Results fetched from a database query (backend independent due
to the data structure of this parameter)
:type query_result: :class:`tuple` or :class:`list` when a single attribute is
given from ICAT backend, or :class:`sqlalchemy.engine.row.Row` when used on the
DB backend
:return: Dictionary of attribute names paired with the results, ready to be
returned to the user
"""
result_dict = {}
for attr_name, data in zip(distinct_attributes, query_result):
# Splitting attribute names in case it's from a related entity
split_attr_name = attr_name.split(".")

if isinstance(data, datetime):
# Workaround for when this function is used on DB backend, where usually
# `_make_serializable()` would fix tzinfo
if data.tzinfo is None:
data = data.replace(tzinfo=tzlocal())
data = DateHandler.datetime_object_to_str(data)

# Attribute name is from the 'origin' entity (i.e. not a related entity)
if len(split_attr_name) == 1:
result_dict[attr_name] = data
# Attribute name is a related entity, dictionary needs to be nested
else:
result_dict.update(map_nested_attrs({}, split_attr_name, data))

return result_dict


def map_nested_attrs(nested_dict, split_attr_name, query_data):
"""
A function that can be called recursively to map attributes from related
entities to the associated data
:param nested_dict: Dictionary to insert data into
:type nested_dict: :class:`dict`
:param split_attr_name: List of parts to an attribute name, that have been split
by "."
:type split_attr_name: :class:`list`
:param query_data: Data to be added to the dictionary
:type query_data: :class:`str` or :class:`str`
:return: Dictionary to be added to the result dictionary
"""
# Popping LHS of related attribute name to see if it's an attribute name or part
# of a path to a related entity
attr_name_pop = split_attr_name.pop(0)

# Related attribute name, ready to insert data into dictionary
if len(split_attr_name) == 0:
# at role, so put data in
nested_dict[attr_name_pop] = query_data
# Part of the path for related entity, need to recurse to get to attribute name
else:
nested_dict[attr_name_pop] = {}
map_nested_attrs(nested_dict[attr_name_pop], split_attr_name, query_data)

return nested_dict
Loading

0 comments on commit 205f155

Please sign in to comment.