diff --git a/datagateway_api/src/datagateway_api/filter_order_handler.py b/datagateway_api/src/common/filter_order_handler.py similarity index 55% rename from datagateway_api/src/datagateway_api/filter_order_handler.py rename to datagateway_api/src/common/filter_order_handler.py index 7c4615e4..093a9192 100644 --- a/datagateway_api/src/datagateway_api/filter_order_handler.py +++ b/datagateway_api/src/common/filter_order_handler.py @@ -1,10 +1,13 @@ import logging from datagateway_api.src.datagateway_api.icat.filters import ( + PythonICATIncludeFilter, PythonICATLimitFilter, PythonICATOrderFilter, PythonICATSkipFilter, ) +from datagateway_api.src.search_api.filters import SearchAPIIncludeFilter +from datagateway_api.src.search_api.panosc_mappings import mappings log = logging.getLogger() @@ -44,6 +47,58 @@ def apply_filters(self, query): for query_filter in self.filters: query_filter.apply_filter(query) + def add_icat_relations_for_non_related_fields_of_panosc_related_entities( + self, panosc_entity_name, + ): + """ + When there are Search API included filters, get the ICAT relations (if any) for + the non-related fields of all the entities in the relations. Once retrieved, + add them to the `included_filters` list of a `PythonICATIncludeFilter` object + that may already exist in `self.filters`. If such filter does not exist in + `self.filters` then create a new `PythonICATIncludeFilter` object, passing the + ICAT relations to it. Doing this will ensure that ICAT related entities that + map to non-related PaNOSC fields are included in the call made to ICAT. + + A `PythonICATIncludeFilter` object can exist in `self.filters` when one is + created and added in the `get_search` method. This is done when the the PaNOSC + entity for which search is been retrieved has non-related fields that have + ICAT relations. For example, the Document entity has non-related fields that + map to the `keywords` and `type` ICAT entities that are related to the + `investigation` entity. + + :param panosc_entity_name: A PaNOSC entity name e.g. "Dataset" + :type panosc_entity_name: :class:`str` + """ + + python_icat_include_filter = None + icat_relations = [] + for filter_ in self.filters: + if type(filter_) == PythonICATIncludeFilter: + # Using `type` as `isinstance` would return `True` for any class that + # inherits `PythonICATIncludeFilter` e.g. `SearchAPIIncludeFilter`.` + python_icat_include_filter = filter_ + elif isinstance(filter_, SearchAPIIncludeFilter): + included_filters = filter_.included_filters + for included_filter in included_filters: + icat_relations.extend( + mappings.get_icat_relations_for_non_related_fields_of_panosc_relation( # noqa: B950 + panosc_entity_name, included_filter, + ), + ) + + if icat_relations: + log.info( + "Including ICAT relations of non-related fields of related PaNOSC " + "entities", + ) + # Remove any duplicate ICAT relations + icat_relations = list(dict.fromkeys(icat_relations)) + if python_icat_include_filter: + python_icat_include_filter.included_filters.extend(icat_relations) + else: + python_icat_include_filter = PythonICATIncludeFilter(icat_relations) + self.filters.append(python_icat_include_filter) + def merge_python_icat_limit_skip_filters(self): """ When there are both limit and skip filters in a request, merge them into the diff --git a/datagateway_api/src/datagateway_api/database/helpers.py b/datagateway_api/src/datagateway_api/database/helpers.py index f6affd85..0c473b28 100644 --- a/datagateway_api/src/datagateway_api/database/helpers.py +++ b/datagateway_api/src/datagateway_api/database/helpers.py @@ -11,6 +11,7 @@ BadRequestError, MissingRecordError, ) +from datagateway_api.src.common.filter_order_handler import FilterOrderHandler from datagateway_api.src.common.helpers import map_distinct_attributes_to_results from datagateway_api.src.datagateway_api.database.filters import ( DatabaseDistinctFieldFilter, @@ -25,7 +26,6 @@ INVESTIGATIONINSTRUMENT, SESSION, ) -from datagateway_api.src.datagateway_api.filter_order_handler import FilterOrderHandler log = logging.getLogger() diff --git a/datagateway_api/src/datagateway_api/icat/helpers.py b/datagateway_api/src/datagateway_api/icat/helpers.py index 2df47078..55b79c77 100644 --- a/datagateway_api/src/datagateway_api/icat/helpers.py +++ b/datagateway_api/src/datagateway_api/icat/helpers.py @@ -21,7 +21,7 @@ MissingRecordError, PythonICATError, ) -from datagateway_api.src.datagateway_api.filter_order_handler import FilterOrderHandler +from datagateway_api.src.common.filter_order_handler import FilterOrderHandler from datagateway_api.src.datagateway_api.icat.filters import ( PythonICATLimitFilter, PythonICATWhereFilter, diff --git a/datagateway_api/src/search_api/helpers.py b/datagateway_api/src/search_api/helpers.py index cf67833d..8a6e033e 100644 --- a/datagateway_api/src/search_api/helpers.py +++ b/datagateway_api/src/search_api/helpers.py @@ -1,6 +1,8 @@ import logging -from datagateway_api.src.datagateway_api.filter_order_handler import FilterOrderHandler +from datagateway_api.src.common.filter_order_handler import FilterOrderHandler +from datagateway_api.src.datagateway_api.icat.filters import PythonICATIncludeFilter +from datagateway_api.src.search_api.panosc_mappings import mappings from datagateway_api.src.search_api.query import SearchAPIQuery from datagateway_api.src.search_api.session_handler import ( client_manager, @@ -15,6 +17,14 @@ def get_search(endpoint_name, entity_name, filters): log.debug("Entity Name: %s, Filters: %s", entity_name, filters) + icat_relations = mappings.get_icat_relations_for_panosc_non_related_fields( + entity_name, + ) + # Remove any duplicate ICAT relations + icat_relations = list(dict.fromkeys(icat_relations)) + if icat_relations: + filters.append(PythonICATIncludeFilter(icat_relations)) + query = SearchAPIQuery(entity_name) filter_handler = FilterOrderHandler() diff --git a/datagateway_api/src/search_api/panosc_mappings.py b/datagateway_api/src/search_api/panosc_mappings.py index 0af7a5ed..695c0245 100644 --- a/datagateway_api/src/search_api/panosc_mappings.py +++ b/datagateway_api/src/search_api/panosc_mappings.py @@ -98,5 +98,108 @@ def get_panosc_related_entity_name( return panosc_related_entity_name + def get_panosc_non_related_field_names(self, panosc_entity_name): + """ + This function retrieves the names of the non related fields of a given PaNOSC + entity. + + :param panosc_entity_name: A PaNOSC entity name e.g. "Dataset" + :type panosc_entity_name: :class:`str` + :return: List containing the names of the non related fields of the given + PaNOSC entity + :raises FilterError: If mappings for the given entity name cannot be found + """ + try: + entity_mappings = self.mappings[panosc_entity_name] + except KeyError: + raise FilterError( + f"Cannot find mappings for {[panosc_entity_name]} PaNOSC entity", + ) + + non_related_field_names = [] + for mapping_key, mapping_value in entity_mappings.items(): + # The mappings for the non-related fields are of type `str` and sometimes + # `list' whereas for the related fields, they are of type `dict`. + if mapping_key != "base_icat_entity" and ( + isinstance(mapping_value, str) or isinstance(mapping_value, list) + ): + non_related_field_names.append(mapping_key) + + return non_related_field_names + + def get_icat_relations_for_panosc_non_related_fields(self, panosc_entity_name): + """ + This function retrieves the ICAT relations for the non related fields of a + given PaNOSC entity. + + :param panosc_entity_name: A PaNOSC entity name e.g. "Dataset" + :type panosc_entity_name: :class:`str` + :return: List containing the ICAT relations for the non related fields of the + given PaNOSC entity + """ + icat_relations = [] + + field_names = self.get_panosc_non_related_field_names(panosc_entity_name) + for field_name in field_names: + _, icat_mapping = self.get_icat_mapping(panosc_entity_name, field_name) + + if not isinstance(icat_mapping, list): + icat_mapping = [icat_mapping] + + for mapping in icat_mapping: + split_mapping = mapping.split(".") + if len(split_mapping) > 1: + # Remove the last split element because it is an ICAT + # field name and is not therefore part of the relation + split_mapping = split_mapping[:-1] + split_mapping = ".".join(split_mapping) + icat_relations.append(split_mapping) + + return icat_relations + + def get_icat_relations_for_non_related_fields_of_panosc_relation( + self, panosc_entity_name, entity_relation, + ): + """ + THis function retrieves the ICAT relations for the non related fields of all the + PaNOSC entities that form a given PaNOSC entity relation which is applied to a + given PaNOSC entity. Relations can be non-nested or nested. Those that are + nested are represented in a dotted format e.g. "documents.members.person". When + a given relation is nested, this function retrieves the ICAT relations for the + first PaNOSC entity and then recursively calls itself until the ICAT relations + for the last PaNOSC entity in the relation are retrieved. + + :param panosc_entity_name: A PaNOSC entity name e.g. "Dataset" to which the + PaNOSC entity relation is applied + :type panosc_entity_name: :class:`str` + :param panosc_entity_name: A PaNOSC entity relation e.g. "documents" or + "documents.members.person" if nested + :type panosc_entity_name: :class:`str` + :return: List containing the ICAT relations for the non related fields of all + the PaNOSC entitities that form the given PaNOSC entity relation + """ + icat_relations = [] + + split_entity_relation = entity_relation.split(".") + related_entity_name, icat_field_name = self.get_icat_mapping( + panosc_entity_name, split_entity_relation[0], + ) + relations = self.get_icat_relations_for_panosc_non_related_fields( + related_entity_name, + ) + icat_relations.extend(relations) + + if len(split_entity_relation) > 1: + entity_relation = ".".join(split_entity_relation[1:]) + relations = self.get_icat_relations_for_non_related_fields_of_panosc_relation( # noqa: B950 + related_entity_name, entity_relation, + ) + icat_relations.extend(relations) + + for i, icat_relation in enumerate(icat_relations): + icat_relations[i] = f"{icat_field_name}.{icat_relation}" + + return icat_relations + mappings = PaNOSCMappings() diff --git a/test/conftest.py b/test/conftest.py index 83b2ab1a..9cc80041 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -3,13 +3,15 @@ from unittest.mock import mock_open, patch from flask import Flask +from icat.client import Client +from icat.query import Query import pytest from datagateway_api.src.api_start_utils import ( create_api_endpoints, create_app_infrastructure, ) -from datagateway_api.src.common.config import APIConfig +from datagateway_api.src.common.config import APIConfig, Config from datagateway_api.src.datagateway_api.database.helpers import ( delete_row_by_id, insert_row_into_table, @@ -17,6 +19,23 @@ from datagateway_api.src.datagateway_api.database.models import SESSION +@pytest.fixture(scope="package") +def icat_client(): + client = Client( + Config.config.datagateway_api.icat_url, + checkCert=Config.config.datagateway_api.icat_check_cert, + ) + client.login( + Config.config.test_mechanism, Config.config.test_user_credentials.dict(), + ) + return client + + +@pytest.fixture() +def icat_query(icat_client): + return Query(icat_client, "Investigation") + + @pytest.fixture() def bad_credentials_header(): return {"Authorization": "Bearer Invalid"} diff --git a/test/datagateway_api/icat/conftest.py b/test/datagateway_api/icat/conftest.py index c7ca8020..388f41aa 100644 --- a/test/datagateway_api/icat/conftest.py +++ b/test/datagateway_api/icat/conftest.py @@ -3,9 +3,7 @@ from dateutil.tz import tzlocal from flask import Flask -from icat.client import Client from icat.exception import ICATNoObjectError -from icat.query import Query import pytest from datagateway_api.src.api_start_utils import ( @@ -17,28 +15,11 @@ from test.datagateway_api.icat.test_query import prepare_icat_data_for_assertion -@pytest.fixture(scope="package") -def icat_client(): - client = Client( - Config.config.datagateway_api.icat_url, - checkCert=Config.config.datagateway_api.icat_check_cert, - ) - client.login( - Config.config.test_mechanism, Config.config.test_user_credentials.dict(), - ) - return client - - @pytest.fixture() def valid_icat_credentials_header(icat_client): return {"Authorization": f"Bearer {icat_client.sessionId}"} -@pytest.fixture() -def icat_query(icat_client): - return Query(icat_client, "Investigation") - - def create_investigation_test_data(client, num_entities=1): test_data = [] diff --git a/test/datagateway_api/icat/filters/test_limit_filter.py b/test/datagateway_api/icat/filters/test_limit_filter.py index b5cfbf8c..8921ac37 100644 --- a/test/datagateway_api/icat/filters/test_limit_filter.py +++ b/test/datagateway_api/icat/filters/test_limit_filter.py @@ -3,7 +3,7 @@ import pytest from datagateway_api.src.common.exceptions import FilterError -from datagateway_api.src.datagateway_api.filter_order_handler import FilterOrderHandler +from datagateway_api.src.common.filter_order_handler import FilterOrderHandler from datagateway_api.src.datagateway_api.icat.filters import ( icat_set_limit, PythonICATLimitFilter, diff --git a/test/datagateway_api/icat/filters/test_order_filter.py b/test/datagateway_api/icat/filters/test_order_filter.py index 0ab03681..406fd37b 100644 --- a/test/datagateway_api/icat/filters/test_order_filter.py +++ b/test/datagateway_api/icat/filters/test_order_filter.py @@ -2,7 +2,7 @@ from typing_extensions import OrderedDict from datagateway_api.src.common.exceptions import FilterError -from datagateway_api.src.datagateway_api.filter_order_handler import FilterOrderHandler +from datagateway_api.src.common.filter_order_handler import FilterOrderHandler from datagateway_api.src.datagateway_api.icat.filters import PythonICATOrderFilter diff --git a/test/datagateway_api/icat/filters/test_where_filter.py b/test/datagateway_api/icat/filters/test_where_filter.py index 254d8098..0210e845 100644 --- a/test/datagateway_api/icat/filters/test_where_filter.py +++ b/test/datagateway_api/icat/filters/test_where_filter.py @@ -1,7 +1,7 @@ import pytest from datagateway_api.src.common.exceptions import BadRequestError, FilterError -from datagateway_api.src.datagateway_api.filter_order_handler import FilterOrderHandler +from datagateway_api.src.common.filter_order_handler import FilterOrderHandler from datagateway_api.src.datagateway_api.icat.filters import PythonICATWhereFilter diff --git a/test/datagateway_api/icat/test_filter_order_handler.py b/test/datagateway_api/icat/test_filter_order_handler.py deleted file mode 100644 index 6e457741..00000000 --- a/test/datagateway_api/icat/test_filter_order_handler.py +++ /dev/null @@ -1,72 +0,0 @@ -import pytest - -from datagateway_api.src.datagateway_api.filter_order_handler import FilterOrderHandler -from datagateway_api.src.datagateway_api.icat.filters import ( - PythonICATLimitFilter, - PythonICATWhereFilter, -) - - -class TestFilterOrderHandler: - """ - `merge_python_icat_limit_skip_filters` and`clear_python_icat_order_filters()` are - tested while testing the ICAT backend filters, so tests of these functions won't be - found here - """ - - def test_add_filter(self, icat_query): - test_handler = FilterOrderHandler() - test_filter = PythonICATWhereFilter("id", 2, "eq") - - test_handler.add_filter(test_filter) - - assert test_handler.filters == [test_filter] - - def test_add_filters(self): - test_handler = FilterOrderHandler() - id_filter = PythonICATWhereFilter("id", 2, "eq") - name_filter = PythonICATWhereFilter("name", "New Name", "like") - filter_list = [id_filter, name_filter] - - test_handler.add_filters(filter_list) - - assert test_handler.filters == filter_list - - def test_remove_filter(self): - test_filter = PythonICATWhereFilter("id", 2, "eq") - - test_handler = FilterOrderHandler() - test_handler.add_filter(test_filter) - test_handler.remove_filter(test_filter) - - assert test_handler.filters == [] - - def test_remove_not_added_filter(self): - test_handler = FilterOrderHandler() - test_filter = PythonICATWhereFilter("id", 2, "eq") - - with pytest.raises(ValueError): - test_handler.remove_filter(test_filter) - - def test_sort_filters(self): - limit_filter = PythonICATLimitFilter(10) - where_filter = PythonICATWhereFilter("id", 2, "eq") - - test_handler = FilterOrderHandler() - test_handler.add_filters([limit_filter, where_filter]) - test_handler.sort_filters() - - assert test_handler.filters == [where_filter, limit_filter] - - def test_apply_filters(self, icat_query): - where_filter = PythonICATWhereFilter("id", 2, "eq") - limit_filter = PythonICATLimitFilter(10) - - test_handler = FilterOrderHandler() - test_handler.add_filters([where_filter, limit_filter]) - test_handler.apply_filters(icat_query) - - assert icat_query.conditions == {"id": ["%s = '2'"]} and icat_query.limit == ( - 0, - 10, - ) diff --git a/test/search_api/filters/test_search_api_include_filter.py b/test/search_api/filters/test_search_api_include_filter.py index 241295f4..e0ba2e69 100644 --- a/test/search_api/filters/test_search_api_include_filter.py +++ b/test/search_api/filters/test_search_api_include_filter.py @@ -1,6 +1,6 @@ import pytest -from datagateway_api.src.datagateway_api.filter_order_handler import FilterOrderHandler +from datagateway_api.src.common.filter_order_handler import FilterOrderHandler from datagateway_api.src.search_api.filters import SearchAPIIncludeFilter from datagateway_api.src.search_api.query import SearchAPIQuery diff --git a/test/search_api/filters/test_search_api_where_filter.py b/test/search_api/filters/test_search_api_where_filter.py index a1fdd536..7284a6cc 100644 --- a/test/search_api/filters/test_search_api_where_filter.py +++ b/test/search_api/filters/test_search_api_where_filter.py @@ -1,6 +1,6 @@ import pytest -from datagateway_api.src.datagateway_api.filter_order_handler import FilterOrderHandler +from datagateway_api.src.common.filter_order_handler import FilterOrderHandler from datagateway_api.src.search_api.filters import SearchAPIWhereFilter from datagateway_api.src.search_api.nested_where_filters import NestedWhereFilters from datagateway_api.src.search_api.query import SearchAPIQuery diff --git a/test/search_api/test_panosc_mappings.py b/test/search_api/test_panosc_mappings.py index 4d14054e..9a521de4 100644 --- a/test/search_api/test_panosc_mappings.py +++ b/test/search_api/test_panosc_mappings.py @@ -1,6 +1,6 @@ import pytest -from datagateway_api.src.common.exceptions import SearchAPIError +from datagateway_api.src.common.exceptions import FilterError, SearchAPIError from datagateway_api.src.search_api.panosc_mappings import PaNOSCMappings @@ -36,3 +36,145 @@ def test_invalid_get_panosc_related_entity_name(self, test_panosc_mappings): test_panosc_mappings.get_panosc_related_entity_name( "UnknownField", "unknownField", ) + + @pytest.mark.parametrize( + "test_panosc_entity_name, expected_non_related_field_names", + [ + pytest.param( + "Affiliation", + ["id", "name", "address", "city", "country"], + id="Affiliation", + ), + pytest.param( + "Dataset", + ["pid", "title", "isPublic", "creationDate", "size"], + id="Dataset", + ), + pytest.param( + "Document", + [ + "pid", + "isPublic", + "type", + "title", + "summary", + "doi", + "startDate", + "endDate", + "releaseDate", + "license", + "keywords", + ], + id="Document", + ), + pytest.param("File", ["id", "name", "path", "size"], id="File"), + pytest.param("Instrument", ["pid", "name", "facility"], id="Instrument"), + pytest.param("Member", ["id", "role"], id="Member"), + pytest.param("Parameter", ["id", "name", "value", "unit"], id="Parameter"), + pytest.param( + "Person", + ["id", "fullName", "orcid", "researcherId", "firstName", "lastName"], + id="Person", + ), + pytest.param("Sample", ["name", "pid", "description"], id="Sample"), + pytest.param("Technique", ["pid", "name"], id="Technique"), + ], + ) + def test_valid_get_panosc_non_related_field_names( + self, + test_panosc_mappings, + test_panosc_entity_name, + expected_non_related_field_names, + ): + non_related_field_names = test_panosc_mappings.get_panosc_non_related_field_names( # noqa: B950 + test_panosc_entity_name, + ) + assert non_related_field_names == expected_non_related_field_names + + def test_invalid_get_panosc_non_related_field_names( + self, test_panosc_mappings, + ): + with pytest.raises(FilterError): + test_panosc_mappings.get_panosc_non_related_field_names("UnknownEntity") + + @pytest.mark.parametrize( + "test_panosc_entity_name, expected_icat_relations", + [ + pytest.param("Affiliation", [], id="Affiliation"), + pytest.param("Dataset", [], id="Dataset"), + pytest.param("Document", ["type", "keywords"], id="Document"), + pytest.param("File", [], id="File"), + pytest.param("Instrument", ["facility"], id="Instrument"), + pytest.param("Member", [], id="Member"), + pytest.param("Parameter", ["type", "type"], id="Parameter"), + pytest.param("Person", [], id="Person"), + pytest.param("Sample", ["parameters.type"], id="Sample"), + pytest.param("Technique", [], id="Technique"), + ], + ) + def test_get_icat_relations_for_panosc_non_related_fields( + self, test_panosc_mappings, test_panosc_entity_name, expected_icat_relations, + ): + icat_relations = test_panosc_mappings.get_icat_relations_for_panosc_non_related_fields( # noqa: B950 + test_panosc_entity_name, + ) + assert icat_relations == expected_icat_relations + + @pytest.mark.parametrize( + "test_panosc_entity_name, test_entity_relation, expected_icat_relations", + [ + pytest.param( + "Affiliation", "members", [], id="Affiliation members relation", + ), + pytest.param( + "Affiliation", + "members.document.datasets.samples", + [ + "user.user.investigationUsers.investigation.type", + "user.user.investigationUsers.investigation.keywords", + "user.user.investigationUsers.investigation.datasets.sample" + ".parameters.type", + ], + id="Affiliation members.document.datasets.sample relation", + ), + pytest.param( + "Dataset", + "documents", + ["investigation.type", "investigation.keywords"], + id="Dataset documents relation", + ), + pytest.param( + "Dataset", + "documents.parameters", + [ + "investigation.type", + "investigation.keywords", + "investigation.parameters.type", + "investigation.parameters.type", + ], + id="Dataset documents.parameters relations", + ), + pytest.param( + "Dataset", + "parameters.dataset.documents", + [ + "parameters.type", + "parameters.type", + "parameters.dataset.investigation.type", + "parameters.dataset.investigation.keywords", + ], + id="Dataset parameters.dataset.documents relations", + ), + ], + ) + def test_get_icat_relations_for_non_related_fields_of_panosc_relation( + self, + test_panosc_mappings, + test_panosc_entity_name, + test_entity_relation, + expected_icat_relations, + ): + icat_relations = test_panosc_mappings.get_icat_relations_for_non_related_fields_of_panosc_relation( # noqa: B950 + test_panosc_entity_name, test_entity_relation, + ) + assert icat_relations == expected_icat_relations diff --git a/test/test_filter_order_handler.py b/test/test_filter_order_handler.py new file mode 100644 index 00000000..c02c588d --- /dev/null +++ b/test/test_filter_order_handler.py @@ -0,0 +1,178 @@ +import pytest + +from datagateway_api.src.common.filter_order_handler import FilterOrderHandler +from datagateway_api.src.datagateway_api.icat.filters import ( + PythonICATIncludeFilter, + PythonICATLimitFilter, + PythonICATWhereFilter, +) +from datagateway_api.src.search_api.filters import SearchAPIIncludeFilter + + +class TestFilterOrderHandler: + """ + `merge_python_icat_limit_skip_filters` and`clear_python_icat_order_filters()` are + tested while testing the ICAT backend filters, so tests of these functions won't be + found here + """ + + def test_add_filter(self, icat_query): + test_handler = FilterOrderHandler() + test_filter = PythonICATWhereFilter("id", 2, "eq") + + test_handler.add_filter(test_filter) + + assert test_handler.filters == [test_filter] + + def test_add_filters(self): + test_handler = FilterOrderHandler() + id_filter = PythonICATWhereFilter("id", 2, "eq") + name_filter = PythonICATWhereFilter("name", "New Name", "like") + filter_list = [id_filter, name_filter] + + test_handler.add_filters(filter_list) + + assert test_handler.filters == filter_list + + def test_remove_filter(self): + test_filter = PythonICATWhereFilter("id", 2, "eq") + + test_handler = FilterOrderHandler() + test_handler.add_filter(test_filter) + test_handler.remove_filter(test_filter) + + assert test_handler.filters == [] + + def test_remove_not_added_filter(self): + test_handler = FilterOrderHandler() + test_filter = PythonICATWhereFilter("id", 2, "eq") + + with pytest.raises(ValueError): + test_handler.remove_filter(test_filter) + + def test_sort_filters(self): + limit_filter = PythonICATLimitFilter(10) + where_filter = PythonICATWhereFilter("id", 2, "eq") + + test_handler = FilterOrderHandler() + test_handler.add_filters([limit_filter, where_filter]) + test_handler.sort_filters() + + assert test_handler.filters == [where_filter, limit_filter] + + def test_apply_filters(self, icat_query): + where_filter = PythonICATWhereFilter("id", 2, "eq") + limit_filter = PythonICATLimitFilter(10) + + test_handler = FilterOrderHandler() + test_handler.add_filters([where_filter, limit_filter]) + test_handler.apply_filters(icat_query) + + assert icat_query.conditions == {"id": ["%s = '2'"]} and icat_query.limit == ( + 0, + 10, + ) + + @pytest.mark.parametrize( + "test_panosc_entity_name, test_filters, expected_filters_length," + "expected_num_of_python_include_filters, expected_icat_relations", + [ + pytest.param( + "Dataset", [], 0, 0, [], id="Dataset without related entities", + ), + pytest.param( + "Dataset", + [SearchAPIIncludeFilter(["documents"], "Dataset")], + 2, + 1, + ["investigation.type", "investigation.keywords"], + id="Dataset with related entity", + ), + pytest.param( + "Dataset", + [SearchAPIIncludeFilter(["documents", "instrument"], "Dataset")], + 2, + 1, + [ + "investigation.type", + "investigation.keywords", + "datasetInstruments.instrument.facility", + ], + id="Dataset with related entities", + ), + pytest.param( + "Dataset", + [SearchAPIIncludeFilter(["documents.parameters.document"], "Dataset")], + 2, + 1, + [ + "investigation.type", + "investigation.keywords", + "investigation.parameters.type", + "investigation.parameters.investigation.type", + "investigation.parameters.investigation.keywords", + ], + id="Dataset with nested related entity", + ), + pytest.param( + "Dataset", + [ + SearchAPIIncludeFilter( + [ + "documents.parameters.document", + "parameters.dataset.instrument", + ], + "Dataset", + ), + ], + 2, + 1, + [ + "investigation.type", + "investigation.keywords", + "investigation.parameters.type", + "investigation.parameters.investigation.type", + "investigation.parameters.investigation.keywords", + "parameters.type", + "parameters.dataset.datasetInstruments.instrument.facility", + ], + id="Dataset with nested related entities", + ), + pytest.param( + "Document", + [ + SearchAPIIncludeFilter(["parameters"], "Document"), + PythonICATIncludeFilter(["type", "keywords"]), + ], + 2, + 1, + ["type", "keywords", "parameters.type"], + id="Document with related entity", + ), + ], + ) + def test_add_icat_relations_for_non_related_fields_of_panosc_related_entities( + self, + test_panosc_entity_name, + test_filters, + expected_filters_length, + expected_num_of_python_include_filters, + expected_icat_relations, + ): + handler = FilterOrderHandler() + handler.add_filters(test_filters) + handler.add_icat_relations_for_non_related_fields_of_panosc_related_entities( + test_panosc_entity_name, + ) + + actual_num_of_python_include_filters = 0 + for filter_ in handler.filters: + if type(filter_) == PythonICATIncludeFilter: + actual_num_of_python_include_filters += 1 + assert filter_.included_filters == expected_icat_relations + + assert ( + actual_num_of_python_include_filters + == expected_num_of_python_include_filters + ) + assert len(handler.filters) == expected_filters_length