From 1b7f1325c433cf0917bdbc8e557468a1f110296e Mon Sep 17 00:00:00 2001 From: Matthew Richards Date: Thu, 6 May 2021 17:29:38 +0000 Subject: [PATCH 01/20] #223: Allow DatabaseDistinctFilter to recognise related entity inputs - This takes existing code from the DatabaseWhereFilter and makes it generic. Future commits will get the WhereFilter to also use this generic version in the same way as the DistinctFilter --- datagateway_api/common/database/filters.py | 94 +++++++++++++++++++++- 1 file changed, 90 insertions(+), 4 deletions(-) diff --git a/datagateway_api/common/database/filters.py b/datagateway_api/common/database/filters.py index 06209c22..b2545cd7 100644 --- a/datagateway_api/common/database/filters.py +++ b/datagateway_api/common/database/filters.py @@ -18,6 +18,85 @@ log = logging.getLogger() +class DatabaseFilterUtilities: + """ + Class containing utility functions used in the WhereFilter and DistinctFilter + + In this class, the terminology of 'included entities' has been made more generic to + 'related entities'. When these functions are used with the WhereFilter, the related + entities are in fact included entities (entities which are also present in the input + of an include filter in the same request). However, when these functions are used + with the DistinctFilter, they are related entities, not included entities as there's + no requirement for the entity names to also be present in an include filter (this is + to match the ICAT backend) + """ + + def __init__(self): + self.related_field = None + self.related_related_field = None + + def _extract_filter_fields(self, field): + """ + Extract the related fields names and put them into separate variables + + :param field: ICAT field names, separated by dots + :type field: :class:`str` + :raises ValueError: If the maximum related/included depth is exceeded + """ + + fields = field.split(".") + related_depth = len(fields) + + log.debug("Fields: %s, Related Depth: %d", fields, related_depth) + + if related_depth == 1: + self.field = fields[0] + elif related_depth == 2: + self.field = fields[0] + self.related_field = fields[1] + elif related_depth == 3: + self.field = fields[0] + self.related_field = fields[1] + self.related_related_field = fields[2] + else: + raise ValueError(f"Maximum related depth exceeded. {field}'s depth > 3") + + # TODO - Remove if not needed + # return (field, related_field, related_related_field) + + def _add_query_join(self, query): + """ + Fetches the appropriate entity model based on the contents of `self.field` and + adds any required JOINs to the query if any related fields have been used in the + filter + + :param query: The query to have filters applied to + :type query: :class:`datagateway_api.common.database.helpers.[QUERY]` + :return: Entity model of the field (usually the field relating to the endpoint + the request is coming from) + """ + try: + field = getattr(query.table, self.field) + except AttributeError: + raise FilterError( + f"Unknown attribute {self.field} on table {query.table.__name__}", + ) + + if self.related_related_field: + included_table = getattr(models, self.field) + included_included_table = getattr(models, self.related_field) + query.base_query = query.base_query.join(included_table).join( + included_included_table, + ) + field = getattr(included_included_table, self.related_related_field) + elif self.related_field: + included_table = get_entity_object_from_name(self.field) + query.base_query = query.base_query.join(included_table) + field = getattr(included_table, self.related_field) + + return field + + class DatabaseWhereFilter(WhereFilter): def __init__(self, field, value, operation): super().__init__(field, value, operation) @@ -95,17 +174,24 @@ def apply_filter(self, query): ) -class DatabaseDistinctFieldFilter(DistinctFieldFilter): +class DatabaseDistinctFieldFilter(DistinctFieldFilter, DatabaseFilterUtilities): def __init__(self, fields): - super().__init__(fields) + # TODO - what's the Pythonic solution here? + # super().__init__(fields) + DistinctFieldFilter.__init__(self, fields) + DatabaseFilterUtilities.__init__(self) def apply_filter(self, query): query.is_distinct_fields_query = True try: - self.fields = [getattr(query.table, field) for field in self.fields] + distinct_fields = [] + for field_name in self.fields: + self._extract_filter_fields(field_name) + field = self._add_query_join(query) + distinct_fields.append(field) except AttributeError: raise FilterError("Bad field requested") - query.base_query = query.session.query(*self.fields).distinct() + query.base_query = query.session.query(*distinct_fields).distinct() class DatabaseOrderFilter(OrderFilter): From 11f75a20a2dba6aac446fbaeb8d61a3f8a436233 Mon Sep 17 00:00:00 2001 From: Matthew Richards Date: Thu, 6 May 2021 17:49:22 +0000 Subject: [PATCH 02/20] #223: Implement DatabaseFilterUtilities into DatabaseWhereFilter - Similar implemenation as DatabaseDistinctFieldFilter --- datagateway_api/common/database/filters.py | 55 ++-------------------- 1 file changed, 5 insertions(+), 50 deletions(-) diff --git a/datagateway_api/common/database/filters.py b/datagateway_api/common/database/filters.py index b2545cd7..6bb41031 100644 --- a/datagateway_api/common/database/filters.py +++ b/datagateway_api/common/database/filters.py @@ -61,9 +61,6 @@ def _extract_filter_fields(self, field): else: raise ValueError(f"Maximum related depth exceeded. {field}'s depth > 3") - # TODO - Remove if not needed - # return (field, related_field, related_related_field) - def _add_query_join(self, query): """ Fetches the appropriate entity model based on the contents of `self.field` and @@ -97,58 +94,16 @@ def _add_query_join(self, query): return field -class DatabaseWhereFilter(WhereFilter): +class DatabaseWhereFilter(WhereFilter, DatabaseFilterUtilities): def __init__(self, field, value, operation): - super().__init__(field, value, operation) + # TODO - Apply any 'pythonic' solution here too + WhereFilter.__init__(self, field, value, operation) + DatabaseFilterUtilities.__init__(self) - self.included_field = None - self.included_included_field = None self._extract_filter_fields(field) - def _extract_filter_fields(self, field): - """ - Extract the related fields names and put them into separate variables - - :param field: ICAT field names, separated by dots - :type field: :class:`str` - """ - - fields = field.split(".") - include_depth = len(fields) - - log.debug("Fields: %s, Include Depth: %d", fields, include_depth) - - if include_depth == 1: - self.field = fields[0] - elif include_depth == 2: - self.field = fields[0] - self.included_field = fields[1] - elif include_depth == 3: - self.field = fields[0] - self.included_field = fields[1] - self.included_included_field = fields[2] - else: - raise ValueError(f"Maximum include depth exceeded. {field}'s depth > 3") - def apply_filter(self, query): - try: - field = getattr(query.table, self.field) - except AttributeError: - raise FilterError( - f"Unknown attribute {self.field} on table {query.table.__name__}", - ) - - if self.included_included_field: - included_table = getattr(models, self.field) - included_included_table = getattr(models, self.included_field) - query.base_query = query.base_query.join(included_table).join( - included_included_table, - ) - field = getattr(included_included_table, self.included_included_field) - elif self.included_field: - included_table = get_entity_object_from_name(self.field) - query.base_query = query.base_query.join(included_table) - field = getattr(included_table, self.included_field) + field = self._add_query_join(query) if self.operation == "eq": query.base_query = query.base_query.filter(field == self.value) From 4bc637a158536246b6c3ce2d4a4f63256edd8a05 Mon Sep 17 00:00:00 2001 From: Matthew Richards Date: Tue, 11 May 2021 09:20:18 +0000 Subject: [PATCH 03/20] #223: Allow related entities to be given in ICAT schema form --- datagateway_api/common/database/filters.py | 23 ++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/datagateway_api/common/database/filters.py b/datagateway_api/common/database/filters.py index 6bb41031..ffce968f 100644 --- a/datagateway_api/common/database/filters.py +++ b/datagateway_api/common/database/filters.py @@ -72,27 +72,30 @@ def _add_query_join(self, query): :return: Entity model of the field (usually the field relating to the endpoint the request is coming from) """ - try: - field = getattr(query.table, self.field) - except AttributeError: - raise FilterError( - f"Unknown attribute {self.field} on table {query.table.__name__}", - ) if self.related_related_field: - included_table = getattr(models, self.field) - included_included_table = getattr(models, self.related_field) + included_table = get_entity_object_from_name(self.field) + included_included_table = get_entity_object_from_name(self.related_field) query.base_query = query.base_query.join(included_table).join( included_included_table, ) - field = getattr(included_included_table, self.related_related_field) + field = self._get_field(included_included_table, self.related_related_field) elif self.related_field: included_table = get_entity_object_from_name(self.field) query.base_query = query.base_query.join(included_table) - field = getattr(included_table, self.related_field) + field = self._get_field(included_table, self.related_field) + else: + # No related fields + field = self._get_field(query.table, self.field) return field + def _get_field(self, table, field): + try: + return getattr(table, field) + except AttributeError: + raise FilterError(f"Unknown attribute {field} on table {table.__name__}") + class DatabaseWhereFilter(WhereFilter, DatabaseFilterUtilities): def __init__(self, field, value, operation): From e67ee26b4c291ad7fe2feb5285c2c21586f7556b Mon Sep 17 00:00:00 2001 From: Matthew Richards Date: Tue, 11 May 2021 13:45:24 +0000 Subject: [PATCH 04/20] #223: Move map distinct attrs to results to common.helpers - These two functions will now also be used in the DB backend, so should be moved to a more common location --- datagateway_api/common/helpers.py | 72 +++++++++++++++++++++++++++ datagateway_api/common/icat/query.py | 74 +--------------------------- 2 files changed, 74 insertions(+), 72 deletions(-) diff --git a/datagateway_api/common/helpers.py b/datagateway_api/common/helpers.py index c9ce2b22..3012265e 100644 --- a/datagateway_api/common/helpers.py +++ b/datagateway_api/common/helpers.py @@ -1,3 +1,5 @@ +from datagateway_api.common.date_handler import DateHandler +from datetime import datetime from functools import wraps import json import logging @@ -126,3 +128,73 @@ def get_entity_object_from_name(entity_name): raise ApiError( f"Entity class cannot be found, missing class for {entity_name}", ) + + +def map_distinct_attributes_to_results(distinct_attributes, query_result): + """ + Maps the attribute names from a distinct filter onto the results given by the result + of a query + + When selecting multiple (but not all) attributes in a database query, the results + are returned in a list and not mapped to an entity object. This means the 'normal' + functions used to process data ready for output (`entity_to_dict()` for the ICAT + backend) cannot be used, as the structure of the query result is different. + + :param distinct_attributes: List of distinct attributes from the distinct + filter of the incoming request + :type distinct_attributes: :class:`list` + :param query_result: Results fetched from a database query (backend independent due + to the data structure of this parameter) + :type query_result: :class:`tuple` or :class:`list` when a single attribute is + given from ICAT backend, TODO + :return: Dictionary of attribute names paired with the results, ready to be + returned to the user + """ + log.debug(f"Query Result Type: {type(query_result)}") + + result_dict = {} + for attr_name, data in zip(distinct_attributes, query_result): + # Splitting attribute names in case it's from a related entity + split_attr_name = attr_name.split(".") + + if isinstance(data, datetime): + data = DateHandler.datetime_object_to_str(data) + + # Attribute name is from the 'origin' entity (i.e. not a related entity) + if len(split_attr_name) == 1: + result_dict[attr_name] = data + # Attribute name is a related entity, dictionary needs to be nested + else: + result_dict.update(map_nested_attrs({}, split_attr_name, data)) + + return result_dict + + +def map_nested_attrs(nested_dict, split_attr_name, query_data): + """ + A function that can be called recursively to map attributes from related + entities to the associated data + + :param nested_dict: Dictionary to insert data into + :type nested_dict: :class:`dict` + :param split_attr_name: List of parts to an attribute name, that have been split + by "." + :type split_attr_name: :class:`list` + :param query_data: Data to be added to the dictionary + :type query_data: :class:`str` or :class:`str` + :return: Dictionary to be added to the result dictionary + """ + # Popping LHS of related attribute name to see if it's an attribute name or part + # of a path to a related entity + attr_name_pop = split_attr_name.pop(0) + + # Related attribute name, ready to insert data into dictionary + if len(split_attr_name) == 0: + # at role, so put data in + nested_dict[attr_name_pop] = query_data + # Part of the path for related entity, need to recurse to get to attribute name + else: + nested_dict[attr_name_pop] = {} + map_nested_attrs(nested_dict[attr_name_pop], split_attr_name, query_data) + + return nested_dict diff --git a/datagateway_api/common/icat/query.py b/datagateway_api/common/icat/query.py index 97e1af41..791714b5 100644 --- a/datagateway_api/common/icat/query.py +++ b/datagateway_api/common/icat/query.py @@ -1,3 +1,4 @@ +from datagateway_api.common.helpers import map_distinct_attributes_to_results from datetime import datetime import logging @@ -123,9 +124,7 @@ def execute_query(self, client, return_json_formattable=False): # Map distinct attributes and result data.append( - self.map_distinct_attributes_to_results( - distinct_attributes, result, - ), + map_distinct_attributes_to_results(distinct_attributes, result), ) elif not count_query: dict_result = self.entity_to_dict(result, flat_query_includes) @@ -198,75 +197,6 @@ def entity_to_dict(self, entity, includes): d[key] = entity_data return d - def map_distinct_attributes_to_results(self, distinct_attributes, query_result): - """ - Maps the attribute names from a distinct filter onto the results given by the - query constructed and executed using Python ICAT - - When selecting multiple (but not all) attributes in a JPQL query, the results - are returned in a list and not mapped to an entity object. As a result, - `entity_to_dict()` cannot be used as that function assumes an entity object - input. Within the API, selecting multiple attributes happens when a distinct - filter is applied to a request. This function is the alternative for processing - data ready for output - - :param distinct_attributes: List of distinct attributes from the distinct - filter of the incoming request - :type distinct_attributes: :class:`list` - :param query_result: Results fetched from Python ICAT - :type query_result: :class:`tuple` or :class:`list` when a single attribute is - given - :return: Dictionary of attribute names paired with the results, ready to be - returned to the user - """ - result_dict = {} - for attr_name, data in zip(distinct_attributes, query_result): - # Splitting attribute names in case it's from a related entity - split_attr_name = attr_name.split(".") - - if isinstance(data, datetime): - data = DateHandler.datetime_object_to_str(data) - - # Attribute name is from the 'origin' entity (i.e. not a related entity) - if len(split_attr_name) == 1: - result_dict[attr_name] = data - # Attribute name is a related entity, dictionary needs to be nested - else: - result_dict.update(self.map_nested_attrs({}, split_attr_name, data)) - - return result_dict - - def map_nested_attrs(self, nested_dict, split_attr_name, query_data): - """ - A function that can be called recursively to map attributes from related - entities to the associated data - - :param nested_dict: Dictionary to insert data into - :type nested_dict: :class:`dict` - :param split_attr_name: List of parts to an attribute name, that have been split - by "." - :type split_attr_name: :class:`list` - :param query_data: Data to be added to the dictionary - :type query_data: :class:`str` or :class:`str` - :return: Dictionary to be added to the result dictionary - """ - # Popping LHS of related attribute name to see if it's an attribute name or part - # of a path to a related entity - attr_name_pop = split_attr_name.pop(0) - - # Related attribute name, ready to insert data into dictionary - if len(split_attr_name) == 0: - # at role, so put data in - nested_dict[attr_name_pop] = query_data - # Part of the path for related entity, need to recurse to get to attribute name - else: - nested_dict[attr_name_pop] = {} - self.map_nested_attrs( - nested_dict[attr_name_pop], split_attr_name, query_data, - ) - - return nested_dict - def flatten_query_included_fields(self, includes): """ This will take the set of fields included in an ICAT query, split up the fields From dec86b71326248b91843a56f845ec9b268b3527a Mon Sep 17 00:00:00 2001 From: Matthew Richards Date: Tue, 11 May 2021 13:46:40 +0000 Subject: [PATCH 05/20] #223: Allow data from related entities from distinct filters to be nested correctly - This is using the functions moved to common.helpers in the previous commit --- datagateway_api/common/database/helpers.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/datagateway_api/common/database/helpers.py b/datagateway_api/common/database/helpers.py index 39e67ce9..1a4ed8ee 100644 --- a/datagateway_api/common/database/helpers.py +++ b/datagateway_api/common/database/helpers.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from datagateway_api.common.helpers import map_distinct_attributes_to_results import datetime from functools import wraps import logging @@ -7,6 +8,7 @@ from sqlalchemy.orm import aliased from datagateway_api.common.database.filters import ( + DatabaseDistinctFieldFilter, DatabaseIncludeFilter as IncludeFilter, DatabaseWhereFilter as WhereFilter, ) @@ -278,7 +280,7 @@ def get_filtered_read_query_results(filter_handler, filters, query): filter_handler.apply_filters(query) results = query.get_all_results() if query.is_distinct_fields_query: - return _get_distinct_fields_as_dicts(results) + return _get_distinct_fields_as_dicts(filters, results) if query.include_related_entities: return _get_results_with_include(filters, results) return list(map(lambda x: x.to_dict(), results)) @@ -298,7 +300,7 @@ def _get_results_with_include(filters, results): return [x.to_nested_dict(query_filter.included_filters) for x in results] -def _get_distinct_fields_as_dicts(results): +def _get_distinct_fields_as_dicts(filters, results): """ Given a list of column results return a list of dictionaries where each column name is the key and the column value is the dictionary key value @@ -306,10 +308,16 @@ def _get_distinct_fields_as_dicts(results): :param results: A list of sql alchemy result objects :return: A list of dictionary representations of the sqlalchemy result objects """ + distinct_fields = [] + for query_filter in filters: + if type(query_filter) is DatabaseDistinctFieldFilter: + distinct_fields.extend(query_filter.fields) + dictionaries = [] for result in results: - dictionary = {k: getattr(result, k) for k in result.keys()} + dictionary = map_distinct_attributes_to_results(distinct_fields, result) dictionaries.append(dictionary) + return dictionaries From d3bda6ecd60fdebae24c485bc389453d3155ce8f Mon Sep 17 00:00:00 2001 From: Matthew Richards Date: Tue, 11 May 2021 13:49:40 +0000 Subject: [PATCH 06/20] #223: Finish detail on docstring --- datagateway_api/common/helpers.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/datagateway_api/common/helpers.py b/datagateway_api/common/helpers.py index 3012265e..607b9211 100644 --- a/datagateway_api/common/helpers.py +++ b/datagateway_api/common/helpers.py @@ -146,12 +146,11 @@ def map_distinct_attributes_to_results(distinct_attributes, query_result): :param query_result: Results fetched from a database query (backend independent due to the data structure of this parameter) :type query_result: :class:`tuple` or :class:`list` when a single attribute is - given from ICAT backend, TODO + given from ICAT backend, or :class:`sqlalchemy.engine.row.Row` when used on the + DB backend :return: Dictionary of attribute names paired with the results, ready to be returned to the user """ - log.debug(f"Query Result Type: {type(query_result)}") - result_dict = {} for attr_name, data in zip(distinct_attributes, query_result): # Splitting attribute names in case it's from a related entity From d21bd670a1931304c9bb87f3c7b413ac9bf5abe6 Mon Sep 17 00:00:00 2001 From: Matthew Richards Date: Tue, 11 May 2021 14:03:14 +0000 Subject: [PATCH 07/20] #223: Move tests for distinct attr mapping --- test/icat/test_query.py | 60 ------------------------------- test/test_map_distinct_attrs.py | 63 +++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 60 deletions(-) create mode 100644 test/test_map_distinct_attrs.py diff --git a/test/icat/test_query.py b/test/icat/test_query.py index 869c6445..de399a18 100644 --- a/test/icat/test_query.py +++ b/test/icat/test_query.py @@ -345,66 +345,6 @@ def test_valid_get_distinct_attributes(self, icat_client): assert test_query.get_distinct_attributes() == ["summary", "name"] - @pytest.mark.parametrize( - "distinct_attrs, result, expected_output", - [ - pytest.param( - ["summary"], - ["Summary 1"], - {"summary": "Summary 1"}, - id="Single attribute", - ), - pytest.param( - ["startDate"], - ( - datetime( - year=2020, - month=1, - day=4, - hour=1, - minute=1, - second=1, - tzinfo=timezone.utc, - ), - ), - {"startDate": "2020-01-04 01:01:01+00:00"}, - id="Single date attribute", - ), - pytest.param( - ["summary", "title"], - ("Summary 1", "Title 1"), - {"summary": "Summary 1", "title": "Title 1"}, - id="Multiple attributes", - ), - pytest.param( - ["summary", "investigationUsers.role"], - ("Summary 1", "PI"), - {"summary": "Summary 1", "investigationUsers": {"role": "PI"}}, - id="Multiple attributes with related attribute", - ), - pytest.param( - ["summary", "investigationUsers.investigation.name"], - ("Summary 1", "Investigation Name 1"), - { - "summary": "Summary 1", - "investigationUsers": { - "investigation": {"name": "Investigation Name 1"}, - }, - }, - id="Multiple attributes with 2-level nested related attribute", - ), - ], - ) - def test_valid_map_distinct_attributes_to_results( - self, icat_client, distinct_attrs, result, expected_output, - ): - test_query = ICATQuery(icat_client, "Investigation") - test_output = test_query.map_distinct_attributes_to_results( - distinct_attrs, result, - ) - - assert test_output == expected_output - def test_include_fields_list_flatten(self, icat_client): included_field_set = { "investigationUsers.investigation.datasets", diff --git a/test/test_map_distinct_attrs.py b/test/test_map_distinct_attrs.py new file mode 100644 index 00000000..c151f9b6 --- /dev/null +++ b/test/test_map_distinct_attrs.py @@ -0,0 +1,63 @@ +from datagateway_api.common.helpers import map_distinct_attributes_to_results +from datetime import datetime, timezone + +import pytest + + +class TestMapDistinctAttrs: + @pytest.mark.parametrize( + "distinct_attrs, result, expected_output", + [ + pytest.param( + ["summary"], + ["Summary 1"], + {"summary": "Summary 1"}, + id="Single attribute", + ), + pytest.param( + ["startDate"], + ( + datetime( + year=2020, + month=1, + day=4, + hour=1, + minute=1, + second=1, + tzinfo=timezone.utc, + ), + ), + {"startDate": "2020-01-04 01:01:01+00:00"}, + id="Single date attribute", + ), + pytest.param( + ["summary", "title"], + ("Summary 1", "Title 1"), + {"summary": "Summary 1", "title": "Title 1"}, + id="Multiple attributes", + ), + pytest.param( + ["summary", "investigationUsers.role"], + ("Summary 1", "PI"), + {"summary": "Summary 1", "investigationUsers": {"role": "PI"}}, + id="Multiple attributes with related attribute", + ), + pytest.param( + ["summary", "investigationUsers.investigation.name"], + ("Summary 1", "Investigation Name 1"), + { + "summary": "Summary 1", + "investigationUsers": { + "investigation": {"name": "Investigation Name 1"}, + }, + }, + id="Multiple attributes with 2-level nested related attribute", + ), + ], + ) + def test_valid_map_distinct_attributes_to_results( + self, distinct_attrs, result, expected_output, + ): + test_output = map_distinct_attributes_to_results(distinct_attrs, result) + + assert test_output == expected_output From 4a799407b5f43de4a0f4ac1fe533594d649efe95 Mon Sep 17 00:00:00 2001 From: Matthew Richards Date: Tue, 11 May 2021 14:13:18 +0000 Subject: [PATCH 08/20] #223: Fix linting issues --- datagateway_api/common/database/filters.py | 1 - datagateway_api/common/database/helpers.py | 2 +- datagateway_api/common/helpers.py | 2 +- datagateway_api/common/icat/query.py | 2 +- test/icat/test_query.py | 2 +- test/test_map_distinct_attrs.py | 3 ++- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/datagateway_api/common/database/filters.py b/datagateway_api/common/database/filters.py index ffce968f..c90c811e 100644 --- a/datagateway_api/common/database/filters.py +++ b/datagateway_api/common/database/filters.py @@ -2,7 +2,6 @@ from sqlalchemy import asc, desc -from datagateway_api.common.database import models from datagateway_api.common.exceptions import FilterError, MultipleIncludeError from datagateway_api.common.filters import ( DistinctFieldFilter, diff --git a/datagateway_api/common/database/helpers.py b/datagateway_api/common/database/helpers.py index 1a4ed8ee..19cb31f9 100644 --- a/datagateway_api/common/database/helpers.py +++ b/datagateway_api/common/database/helpers.py @@ -1,5 +1,4 @@ from abc import ABC, abstractmethod -from datagateway_api.common.helpers import map_distinct_attributes_to_results import datetime from functools import wraps import logging @@ -26,6 +25,7 @@ MissingRecordError, ) from datagateway_api.common.filter_order_handler import FilterOrderHandler +from datagateway_api.common.helpers import map_distinct_attributes_to_results log = logging.getLogger() diff --git a/datagateway_api/common/helpers.py b/datagateway_api/common/helpers.py index 607b9211..452be102 100644 --- a/datagateway_api/common/helpers.py +++ b/datagateway_api/common/helpers.py @@ -1,4 +1,3 @@ -from datagateway_api.common.date_handler import DateHandler from datetime import datetime from functools import wraps import json @@ -9,6 +8,7 @@ from sqlalchemy.exc import IntegrityError from datagateway_api.common.database import models +from datagateway_api.common.date_handler import DateHandler from datagateway_api.common.exceptions import ( ApiError, AuthenticationError, diff --git a/datagateway_api/common/icat/query.py b/datagateway_api/common/icat/query.py index 791714b5..c690c258 100644 --- a/datagateway_api/common/icat/query.py +++ b/datagateway_api/common/icat/query.py @@ -1,4 +1,3 @@ -from datagateway_api.common.helpers import map_distinct_attributes_to_results from datetime import datetime import logging @@ -8,6 +7,7 @@ from datagateway_api.common.date_handler import DateHandler from datagateway_api.common.exceptions import PythonICATError +from datagateway_api.common.helpers import map_distinct_attributes_to_results log = logging.getLogger() diff --git a/test/icat/test_query.py b/test/icat/test_query.py index de399a18..47df71f6 100644 --- a/test/icat/test_query.py +++ b/test/icat/test_query.py @@ -1,4 +1,4 @@ -from datetime import datetime, timezone +from datetime import datetime from icat.entity import Entity import pytest diff --git a/test/test_map_distinct_attrs.py b/test/test_map_distinct_attrs.py index c151f9b6..626345ab 100644 --- a/test/test_map_distinct_attrs.py +++ b/test/test_map_distinct_attrs.py @@ -1,8 +1,9 @@ -from datagateway_api.common.helpers import map_distinct_attributes_to_results from datetime import datetime, timezone import pytest +from datagateway_api.common.helpers import map_distinct_attributes_to_results + class TestMapDistinctAttrs: @pytest.mark.parametrize( From bc9635153bae681e89962fb37d71c14cf16056e7 Mon Sep 17 00:00:00 2001 From: Matthew Richards Date: Thu, 13 May 2021 11:44:05 +0000 Subject: [PATCH 09/20] #223: Add tests for DatabaseFilterUtilities --- test/db/test_database_filter_utilities.py | 104 ++++++++++++++++++++++ 1 file changed, 104 insertions(+) create mode 100644 test/db/test_database_filter_utilities.py diff --git a/test/db/test_database_filter_utilities.py b/test/db/test_database_filter_utilities.py new file mode 100644 index 00000000..9b1e74b7 --- /dev/null +++ b/test/db/test_database_filter_utilities.py @@ -0,0 +1,104 @@ +import pytest + +from datagateway_api.common.database.filters import DatabaseFilterUtilities +from datagateway_api.common.database.helpers import ReadQuery +from datagateway_api.common.exceptions import FilterError +from datagateway_api.common.helpers import get_entity_object_from_name + + +class TestDatabaseFilterUtilities: + @pytest.mark.parametrize( + "input_field, expected_fields", + [ + pytest.param("name", ("name", None, None), id="Unrelated field"), + pytest.param( + "facility.daysUntilRelease", + ("facility", "daysUntilRelease", None), + id="Related field matching ICAT schema name", + ), + pytest.param( + "FACILITY.daysUntilRelease", + ("FACILITY", "daysUntilRelease", None), + id="Related field matching database format (uppercase)", + ), + pytest.param( + "user.investigationUsers.role", + ("user", "investigationUsers", "role"), + id="Related related field (2 levels deep)", + ), + ], + ) + def test_valid_extract_filter_fields(self, input_field, expected_fields): + test_utility = DatabaseFilterUtilities() + test_utility._extract_filter_fields(input_field) + + assert test_utility.field == expected_fields[0] + assert test_utility.related_field == expected_fields[1] + assert test_utility.related_related_field == expected_fields[2] + + def test_invalid_extract_filter_fields(self): + test_utility = DatabaseFilterUtilities() + + with pytest.raises(ValueError): + test_utility._extract_filter_fields( + "user.investigationUsers.investigation.summary", + ) + + @pytest.mark.parametrize( + "input_field", + [ + pytest.param("name", id="No related fields"), + pytest.param("facility.daysUntilRelease", id="Related field"), + pytest.param( + "investigationUsers.user.fullName", id="Related related field", + ), + ], + ) + def test_valid_add_query_join( + self, flask_test_app_db, input_field, + ): + table = get_entity_object_from_name("Investigation") + + test_utility = DatabaseFilterUtilities() + test_utility._extract_filter_fields(input_field) + + expected_query = ReadQuery(table) + if test_utility.related_related_field: + expected_table = get_entity_object_from_name(test_utility.related_field) + + included_table = get_entity_object_from_name(test_utility.field) + expected_query.base_query = expected_query.base_query.join( + included_table, + ).join(expected_table) + elif test_utility.related_field: + expected_table = get_entity_object_from_name(test_utility.field) + + expected_query = ReadQuery(table) + expected_query.base_query = expected_query.base_query.join(expected_table) + else: + expected_table = table + + with ReadQuery(table) as test_query: + output_field = test_utility._add_query_join(test_query) + + # Check the JOIN has been applied + assert str(test_query.base_query) == str(expected_query.base_query) + + # Check the output is correct + field_name_to_fetch = input_field.split(".")[-1] + assert output_field == getattr(expected_table, field_name_to_fetch) + + def test_valid_get_field(self, flask_test_app_db): + table = get_entity_object_from_name("Investigation") + + test_utility = DatabaseFilterUtilities() + field = test_utility._get_field(table, "name") + + assert field == table.name + + def test_invalid_get_field(self, flask_test_app_db): + table = get_entity_object_from_name("Investigation") + + test_utility = DatabaseFilterUtilities() + with pytest.raises(FilterError): + test_utility._get_field(table, "unknown") From a2538ccba6810fbd23c6f2c9d1d09ea89ba8498b Mon Sep 17 00:00:00 2001 From: Matthew Richards Date: Fri, 14 May 2021 10:23:56 +0000 Subject: [PATCH 10/20] #223: Separate out `_add_query_join()` - This function used to have two purposes - the returning of the entity model/field has been moved to another function - Separating out these two jobs means the SQLAlchemy warning has been fixed/no longer appears --- datagateway_api/common/database/filters.py | 45 +++++++++++++++++----- 1 file changed, 36 insertions(+), 9 deletions(-) diff --git a/datagateway_api/common/database/filters.py b/datagateway_api/common/database/filters.py index c90c811e..aed8ba31 100644 --- a/datagateway_api/common/database/filters.py +++ b/datagateway_api/common/database/filters.py @@ -31,6 +31,7 @@ class DatabaseFilterUtilities: """ def __init__(self): + self.field = None self.related_field = None self.related_related_field = None @@ -43,6 +44,11 @@ def _extract_filter_fields(self, field): :raises ValueError: If the maximum related/included depth is exceeded """ + # Flushing fields in case they have been previously set + self.field = None + self.related_field = None + self.related_related_field = None + fields = field.split(".") related_depth = len(fields) @@ -62,14 +68,11 @@ def _extract_filter_fields(self, field): def _add_query_join(self, query): """ - Fetches the appropriate entity model based on the contents of `self.field` and - adds any required JOINs to the query if any related fields have been used in the + Adds any required JOINs to the query if any related fields have been used in the filter :param query: The query to have filters applied to :type query: :class:`datagateway_api.common.database.helpers.[QUERY]` - :return: Entity model of the field (usually the field relating to the endpoint - the request is coming from) """ if self.related_related_field: @@ -78,10 +81,25 @@ def _add_query_join(self, query): query.base_query = query.base_query.join(included_table).join( included_included_table, ) - field = self._get_field(included_included_table, self.related_related_field) elif self.related_field: included_table = get_entity_object_from_name(self.field) query.base_query = query.base_query.join(included_table) + + def _get_entity_model_for_filter(self, query): + """ + Fetches the appropriate entity model based on the contents of the instance + variables of this class + + :param query: The query to have filters applied to + :type query: :class:`datagateway_api.common.database.helpers.[QUERY]` + :return: Entity model of the field (usually the field relating to the endpoint + the request is coming from) + """ + if self.related_related_field: + included_included_table = get_entity_object_from_name(self.related_field) + field = self._get_field(included_included_table, self.related_related_field) + elif self.related_field: + included_table = get_entity_object_from_name(self.field) field = self._get_field(included_table, self.related_field) else: # No related fields @@ -105,7 +123,8 @@ def __init__(self, field, value, operation): self._extract_filter_fields(field) def apply_filter(self, query): - field = self._add_query_join(query) + self._add_query_join(query) + field = self._get_entity_model_for_filter(query) if self.operation == "eq": query.base_query = query.base_query.filter(field == self.value) @@ -140,15 +159,23 @@ def __init__(self, fields): def apply_filter(self, query): query.is_distinct_fields_query = True + try: distinct_fields = [] for field_name in self.fields: self._extract_filter_fields(field_name) - field = self._add_query_join(query) - distinct_fields.append(field) + distinct_fields.append(self._get_entity_model_for_filter(query)) + + # Base query must be set to a DISTINCT query before adding JOINs - if these + # actions are done in the opposite order, the JOINs will overwrite the + # SELECT multiple and effectively turn the query into a `SELECT *` + query.base_query = query.session.query(*distinct_fields).distinct() + + for field_name in self.fields: + self._extract_filter_fields(field_name) + self._add_query_join(query) except AttributeError: raise FilterError("Bad field requested") - query.base_query = query.session.query(*distinct_fields).distinct() class DatabaseOrderFilter(OrderFilter): From 3916a561ce7af57eac13eaeeec488364bdfbc760 Mon Sep 17 00:00:00 2001 From: Matthew Richards Date: Fri, 14 May 2021 11:06:18 +0000 Subject: [PATCH 11/20] #223: Add tests for _get_entity_model_for_filter() --- test/db/test_database_filter_utilities.py | 28 ++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/test/db/test_database_filter_utilities.py b/test/db/test_database_filter_utilities.py index 9b1e74b7..7a19630d 100644 --- a/test/db/test_database_filter_utilities.py +++ b/test/db/test_database_filter_utilities.py @@ -79,11 +79,37 @@ def test_valid_add_query_join( expected_table = table with ReadQuery(table) as test_query: - output_field = test_utility._add_query_join(test_query) + test_utility._add_query_join(test_query) # Check the JOIN has been applied assert str(test_query.base_query) == str(expected_query.base_query) + @pytest.mark.parametrize( + "input_field", + [ + pytest.param("name", id="No related fields"), + pytest.param("facility.daysUntilRelease", id="Related field"), + pytest.param( + "investigationUsers.user.fullName", id="Related related field", + ), + ], + ) + def test_valid_get_entity_model_for_filter(self, input_field): + table = get_entity_object_from_name("Investigation") + + test_utility = DatabaseFilterUtilities() + test_utility._extract_filter_fields(input_field) + + if test_utility.related_related_field: + expected_table = get_entity_object_from_name(test_utility.related_field) + elif test_utility.related_field: + expected_table = get_entity_object_from_name(test_utility.field) + else: + expected_table = table + + with ReadQuery(table) as test_query: + output_field = test_utility._get_entity_model_for_filter(test_query) + # Check the output is correct field_name_to_fetch = input_field.split(".")[-1] assert output_field == getattr(expected_table, field_name_to_fetch) From 0dfa7a384fe637bf3427ad40446bff196547d993 Mon Sep 17 00:00:00 2001 From: Matthew Richards Date: Fri, 14 May 2021 12:35:03 +0000 Subject: [PATCH 12/20] #223: Fix bug where single related distinct field was given - This fixes requests such as: `/investigations?distinct=["investigationtype.createTime"]` --- datagateway_api/common/database/filters.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/datagateway_api/common/database/filters.py b/datagateway_api/common/database/filters.py index aed8ba31..a279971f 100644 --- a/datagateway_api/common/database/filters.py +++ b/datagateway_api/common/database/filters.py @@ -31,9 +31,16 @@ class DatabaseFilterUtilities: """ def __init__(self): + """ + The `distinct_join_flag` tracks if JOINs need to be added to the query - on a + distinct filter, if there's no unrelated fields (i.e. no fields with a + `related_depth` of 1), adding JOINs to the query (using `_add_query_join()`) + will result in a `sqlalchemy.exc.InvalidRequestError` + """ self.field = None self.related_field = None self.related_related_field = None + self.distinct_join_flag = False def _extract_filter_fields(self, field): """ @@ -56,6 +63,7 @@ def _extract_filter_fields(self, field): if related_depth == 1: self.field = fields[0] + self.distinct_join_flag = True elif related_depth == 2: self.field = fields[0] self.related_field = fields[1] @@ -171,9 +179,10 @@ def apply_filter(self, query): # SELECT multiple and effectively turn the query into a `SELECT *` query.base_query = query.session.query(*distinct_fields).distinct() - for field_name in self.fields: - self._extract_filter_fields(field_name) - self._add_query_join(query) + if self.distinct_join_flag: + for field_name in self.fields: + self._extract_filter_fields(field_name) + self._add_query_join(query) except AttributeError: raise FilterError("Bad field requested") From 2623b1838c59cb02faf2d21750182807799b12aa Mon Sep 17 00:00:00 2001 From: Matthew Richards Date: Fri, 14 May 2021 13:33:06 +0000 Subject: [PATCH 13/20] #223: Improve assertion on distinct test case --- test/db/endpoints/test_get_with_filters.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/test/db/endpoints/test_get_with_filters.py b/test/db/endpoints/test_get_with_filters.py index d5e8ee19..c201374b 100644 --- a/test/db/endpoints/test_get_with_filters.py +++ b/test/db/endpoints/test_get_with_filters.py @@ -28,7 +28,7 @@ def test_valid_no_results_get_with_filters( assert test_response.json == [] @pytest.mark.usefixtures("multiple_investigation_test_data_db") - def test_valid_get_with_filters_distinct( + def test_valid_get_with_filters_multiple_distinct( self, flask_test_app_db, valid_db_credentials_header, ): test_response = flask_test_app_db.get( @@ -41,8 +41,7 @@ def test_valid_get_with_filters_distinct( {"title": f"Title for DataGateway API Testing (DB) {i}"} for i in range(5) ] - for title in expected: - assert title in test_response.json + assert test_response.json == expected def test_limit_skip_merge_get_with_filters( self, From f261f7b893d791b6a312b4251b88c1481c919fd2 Mon Sep 17 00:00:00 2001 From: Matthew Richards Date: Fri, 14 May 2021 14:17:20 +0000 Subject: [PATCH 14/20] #223: Add tests for distinct filter with related entities --- test/db/endpoints/test_get_with_filters.py | 65 ++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/test/db/endpoints/test_get_with_filters.py b/test/db/endpoints/test_get_with_filters.py index c201374b..6844cada 100644 --- a/test/db/endpoints/test_get_with_filters.py +++ b/test/db/endpoints/test_get_with_filters.py @@ -1,5 +1,8 @@ import pytest +from datagateway_api.common.constants import Constants +from datagateway_api.common.date_handler import DateHandler + class TestDBGetWithFilters: def test_valid_get_with_filters( @@ -43,6 +46,68 @@ def test_valid_get_with_filters_multiple_distinct( assert test_response.json == expected + @pytest.mark.parametrize( + "distinct_param, expected_response", + [ + pytest.param( + '"title"', + [{"title": "Title for DataGateway API Testing (DB) 0"}], + id="Single unrelated distinct field", + ), + pytest.param( + '["createTime", "investigationInstruments.createTime"]', + [ + { + "createTime": DateHandler.datetime_object_to_str( + Constants.TEST_MOD_CREATE_DATETIME, + ), + "investigationInstruments": { + "createTime": DateHandler.datetime_object_to_str( + Constants.TEST_MOD_CREATE_DATETIME, + ), + }, + }, + ], + id="List containing related distinct field", + ), + pytest.param( + '["createTime", "investigationInstruments.createTime", "facility.id"]', + [ + { + "createTime": DateHandler.datetime_object_to_str( + Constants.TEST_MOD_CREATE_DATETIME, + ), + "facility": {"id": 1}, + "investigationInstruments": { + "createTime": DateHandler.datetime_object_to_str( + Constants.TEST_MOD_CREATE_DATETIME, + ), + }, + }, + ], + id="Multiple related distinct fields", + ), + ], + ) + @pytest.mark.usefixtures("isis_specific_endpoint_data_db") + def test_valid_get_with_filters_related_distinct( + self, + flask_test_app_db, + valid_db_credentials_header, + distinct_param, + expected_response, + ): + test_response = flask_test_app_db.get( + '/investigations?where={"title": {"like": "Title for DataGateway API' + ' Testing (DB)"}}' + f"&distinct={distinct_param}", + headers=valid_db_credentials_header, + ) + + print(test_response.json) + + assert test_response.json == expected_response + def test_limit_skip_merge_get_with_filters( self, flask_test_app_db, From dc583a8155c0e2f20017cf996f6aeded94baedb1 Mon Sep 17 00:00:00 2001 From: Matthew Richards Date: Fri, 14 May 2021 14:48:09 +0000 Subject: [PATCH 15/20] #225: Fix timezone-related issues found by merging master - These issues were a result of writing new code that involved timezones before master got merged in, which had the original fixes for timezones on DB backend --- datagateway_api/common/helpers.py | 5 +++++ test/icat/test_query.py | 1 - test/test_map_distinct_attrs.py | 5 +++-- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/datagateway_api/common/helpers.py b/datagateway_api/common/helpers.py index 452be102..88d8f178 100644 --- a/datagateway_api/common/helpers.py +++ b/datagateway_api/common/helpers.py @@ -3,6 +3,7 @@ import json import logging +from dateutil.tz.tz import tzlocal from flask import request from flask_restful import reqparse from sqlalchemy.exc import IntegrityError @@ -157,6 +158,10 @@ def map_distinct_attributes_to_results(distinct_attributes, query_result): split_attr_name = attr_name.split(".") if isinstance(data, datetime): + # Workaround for when this function is used on DB backend, where usually + # `_make_serializable()` would fix tzinfo + if data.tzinfo is None: + data = data.replace(tzinfo=tzlocal()) data = DateHandler.datetime_object_to_str(data) # Attribute name is from the 'origin' entity (i.e. not a related entity) diff --git a/test/icat/test_query.py b/test/icat/test_query.py index a4dd9fac..47df71f6 100644 --- a/test/icat/test_query.py +++ b/test/icat/test_query.py @@ -1,6 +1,5 @@ from datetime import datetime -from dateutil.tz import tzlocal from icat.entity import Entity import pytest diff --git a/test/test_map_distinct_attrs.py b/test/test_map_distinct_attrs.py index 626345ab..615bc3f5 100644 --- a/test/test_map_distinct_attrs.py +++ b/test/test_map_distinct_attrs.py @@ -1,5 +1,6 @@ -from datetime import datetime, timezone +from datetime import datetime +from dateutil.tz import tzlocal import pytest from datagateway_api.common.helpers import map_distinct_attributes_to_results @@ -25,7 +26,7 @@ class TestMapDistinctAttrs: hour=1, minute=1, second=1, - tzinfo=timezone.utc, + tzinfo=tzlocal(), ), ), {"startDate": "2020-01-04 01:01:01+00:00"}, From c9aa20b1f92cb0471ebacc138e2fa601d525d92e Mon Sep 17 00:00:00 2001 From: Matthew Richards Date: Mon, 17 May 2021 09:24:01 +0000 Subject: [PATCH 16/20] #223: Remove commented code - This was the best solution I could find, particularly as the two init's have different method signatures --- datagateway_api/common/database/filters.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/datagateway_api/common/database/filters.py b/datagateway_api/common/database/filters.py index a279971f..d2944ecd 100644 --- a/datagateway_api/common/database/filters.py +++ b/datagateway_api/common/database/filters.py @@ -124,7 +124,6 @@ def _get_field(self, table, field): class DatabaseWhereFilter(WhereFilter, DatabaseFilterUtilities): def __init__(self, field, value, operation): - # TODO - Apply any 'pythonic' solution here too WhereFilter.__init__(self, field, value, operation) DatabaseFilterUtilities.__init__(self) @@ -160,8 +159,6 @@ def apply_filter(self, query): class DatabaseDistinctFieldFilter(DistinctFieldFilter, DatabaseFilterUtilities): def __init__(self, fields): - # TODO - what's the Pythonic solution here? - # super().__init__(fields) DistinctFieldFilter.__init__(self, fields) DatabaseFilterUtilities.__init__(self) From 9021fc865872586441fe710b29f7cfbd9ca87522 Mon Sep 17 00:00:00 2001 From: Matthew Richards Date: Mon, 17 May 2021 10:45:52 +0000 Subject: [PATCH 17/20] #223: Fix bug with related distinct field with no unrelated fields - This fix replaces the distinct flag and as a result, that flag has been removed. Quite happy with this fix, much better than what I achieved on Friday --- datagateway_api/common/database/filters.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/datagateway_api/common/database/filters.py b/datagateway_api/common/database/filters.py index d2944ecd..c3c72142 100644 --- a/datagateway_api/common/database/filters.py +++ b/datagateway_api/common/database/filters.py @@ -40,7 +40,6 @@ def __init__(self): self.field = None self.related_field = None self.related_related_field = None - self.distinct_join_flag = False def _extract_filter_fields(self, field): """ @@ -174,12 +173,15 @@ def apply_filter(self, query): # Base query must be set to a DISTINCT query before adding JOINs - if these # actions are done in the opposite order, the JOINs will overwrite the # SELECT multiple and effectively turn the query into a `SELECT *` - query.base_query = query.session.query(*distinct_fields).distinct() + query.base_query = ( + query.session.query(*distinct_fields) + .select_from(query.table) + .distinct() + ) - if self.distinct_join_flag: - for field_name in self.fields: - self._extract_filter_fields(field_name) - self._add_query_join(query) + for field_name in self.fields: + self._extract_filter_fields(field_name) + self._add_query_join(query) except AttributeError: raise FilterError("Bad field requested") From f16cb15a4360e6fb818db8409508d6f9352127e0 Mon Sep 17 00:00:00 2001 From: Matthew Richards Date: Mon, 17 May 2021 10:47:17 +0000 Subject: [PATCH 18/20] #223: Add test cases for the bug fixed in previous commit --- test/db/endpoints/test_get_with_filters.py | 31 ++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/test/db/endpoints/test_get_with_filters.py b/test/db/endpoints/test_get_with_filters.py index 6844cada..e43e6b76 100644 --- a/test/db/endpoints/test_get_with_filters.py +++ b/test/db/endpoints/test_get_with_filters.py @@ -54,6 +54,19 @@ def test_valid_get_with_filters_multiple_distinct( [{"title": "Title for DataGateway API Testing (DB) 0"}], id="Single unrelated distinct field", ), + pytest.param( + '"investigationInstruments.createTime"', + [ + { + "investigationInstruments": { + "createTime": DateHandler.datetime_object_to_str( + Constants.TEST_MOD_CREATE_DATETIME, + ), + }, + }, + ], + id="Single related distinct field", + ), pytest.param( '["createTime", "investigationInstruments.createTime"]', [ @@ -68,7 +81,21 @@ def test_valid_get_with_filters_multiple_distinct( }, }, ], - id="List containing related distinct field", + id="Single related distinct field with unrelated field", + ), + pytest.param( + '["investigationInstruments.createTime", "facility.id"]', + [ + { + "facility": {"id": 1}, + "investigationInstruments": { + "createTime": DateHandler.datetime_object_to_str( + Constants.TEST_MOD_CREATE_DATETIME, + ), + }, + }, + ], + id="Multiple related distinct fields", ), pytest.param( '["createTime", "investigationInstruments.createTime", "facility.id"]', @@ -85,7 +112,7 @@ def test_valid_get_with_filters_multiple_distinct( }, }, ], - id="Multiple related distinct fields", + id="Multiple related distinct fields with unrelated field", ), ], ) From 0a202d3d739aceac0e07e3cc06533753bc9da5ed Mon Sep 17 00:00:00 2001 From: Matthew Richards Date: Mon, 17 May 2021 10:59:05 +0000 Subject: [PATCH 19/20] #223: Remove underscore prefixes - `_get_field()` is the only function in that class that's used internally, so I've remvoed the underscores from the other functions as I'm not sure I've used them correctly --- datagateway_api/common/database/filters.py | 22 +++++++++++----------- test/db/test_database_filter_utilities.py | 12 ++++++------ 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/datagateway_api/common/database/filters.py b/datagateway_api/common/database/filters.py index c3c72142..7b773555 100644 --- a/datagateway_api/common/database/filters.py +++ b/datagateway_api/common/database/filters.py @@ -34,14 +34,14 @@ def __init__(self): """ The `distinct_join_flag` tracks if JOINs need to be added to the query - on a distinct filter, if there's no unrelated fields (i.e. no fields with a - `related_depth` of 1), adding JOINs to the query (using `_add_query_join()`) + `related_depth` of 1), adding JOINs to the query (using `add_query_join()`) will result in a `sqlalchemy.exc.InvalidRequestError` """ self.field = None self.related_field = None self.related_related_field = None - def _extract_filter_fields(self, field): + def extract_filter_fields(self, field): """ Extract the related fields names and put them into separate variables @@ -73,7 +73,7 @@ def _extract_filter_fields(self, field): else: raise ValueError(f"Maximum related depth exceeded. {field}'s depth > 3") - def _add_query_join(self, query): + def add_query_join(self, query): """ Adds any required JOINs to the query if any related fields have been used in the filter @@ -92,7 +92,7 @@ def _add_query_join(self, query): included_table = get_entity_object_from_name(self.field) query.base_query = query.base_query.join(included_table) - def _get_entity_model_for_filter(self, query): + def get_entity_model_for_filter(self, query): """ Fetches the appropriate entity model based on the contents of the instance variables of this class @@ -126,11 +126,11 @@ def __init__(self, field, value, operation): WhereFilter.__init__(self, field, value, operation) DatabaseFilterUtilities.__init__(self) - self._extract_filter_fields(field) + self.extract_filter_fields(field) def apply_filter(self, query): - self._add_query_join(query) - field = self._get_entity_model_for_filter(query) + self.add_query_join(query) + field = self.get_entity_model_for_filter(query) if self.operation == "eq": query.base_query = query.base_query.filter(field == self.value) @@ -167,8 +167,8 @@ def apply_filter(self, query): try: distinct_fields = [] for field_name in self.fields: - self._extract_filter_fields(field_name) - distinct_fields.append(self._get_entity_model_for_filter(query)) + self.extract_filter_fields(field_name) + distinct_fields.append(self.get_entity_model_for_filter(query)) # Base query must be set to a DISTINCT query before adding JOINs - if these # actions are done in the opposite order, the JOINs will overwrite the @@ -180,8 +180,8 @@ def apply_filter(self, query): ) for field_name in self.fields: - self._extract_filter_fields(field_name) - self._add_query_join(query) + self.extract_filter_fields(field_name) + self.add_query_join(query) except AttributeError: raise FilterError("Bad field requested") diff --git a/test/db/test_database_filter_utilities.py b/test/db/test_database_filter_utilities.py index 7a19630d..8321bbf8 100644 --- a/test/db/test_database_filter_utilities.py +++ b/test/db/test_database_filter_utilities.py @@ -30,7 +30,7 @@ class TestDatabaseFilterUtilities: ) def test_valid_extract_filter_fields(self, input_field, expected_fields): test_utility = DatabaseFilterUtilities() - test_utility._extract_filter_fields(input_field) + test_utility.extract_filter_fields(input_field) assert test_utility.field == expected_fields[0] assert test_utility.related_field == expected_fields[1] @@ -40,7 +40,7 @@ def test_invalid_extract_filter_fields(self): test_utility = DatabaseFilterUtilities() with pytest.raises(ValueError): - test_utility._extract_filter_fields( + test_utility.extract_filter_fields( "user.investigationUsers.investigation.summary", ) @@ -60,7 +60,7 @@ def test_valid_add_query_join( table = get_entity_object_from_name("Investigation") test_utility = DatabaseFilterUtilities() - test_utility._extract_filter_fields(input_field) + test_utility.extract_filter_fields(input_field) expected_query = ReadQuery(table) if test_utility.related_related_field: @@ -79,7 +79,7 @@ def test_valid_add_query_join( expected_table = table with ReadQuery(table) as test_query: - test_utility._add_query_join(test_query) + test_utility.add_query_join(test_query) # Check the JOIN has been applied assert str(test_query.base_query) == str(expected_query.base_query) @@ -98,7 +98,7 @@ def test_valid_get_entity_model_for_filter(self, input_field): table = get_entity_object_from_name("Investigation") test_utility = DatabaseFilterUtilities() - test_utility._extract_filter_fields(input_field) + test_utility.extract_filter_fields(input_field) if test_utility.related_related_field: expected_table = get_entity_object_from_name(test_utility.related_field) @@ -108,7 +108,7 @@ def test_valid_get_entity_model_for_filter(self, input_field): expected_table = table with ReadQuery(table) as test_query: - output_field = test_utility._get_entity_model_for_filter(test_query) + output_field = test_utility.get_entity_model_for_filter(test_query) # Check the output is correct field_name_to_fetch = input_field.split(".")[-1] From da118af4e7040dba0a3ac457c0cdec12a99a550f Mon Sep 17 00:00:00 2001 From: Matthew Richards Date: Mon, 17 May 2021 11:09:32 +0000 Subject: [PATCH 20/20] #223: Remove irrelevant docstring --- datagateway_api/common/database/filters.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/datagateway_api/common/database/filters.py b/datagateway_api/common/database/filters.py index 7b773555..69d982b6 100644 --- a/datagateway_api/common/database/filters.py +++ b/datagateway_api/common/database/filters.py @@ -31,12 +31,6 @@ class DatabaseFilterUtilities: """ def __init__(self): - """ - The `distinct_join_flag` tracks if JOINs need to be added to the query - on a - distinct filter, if there's no unrelated fields (i.e. no fields with a - `related_depth` of 1), adding JOINs to the query (using `add_query_join()`) - will result in a `sqlalchemy.exc.InvalidRequestError` - """ self.field = None self.related_field = None self.related_related_field = None