From e820c3139dba5bad4f383849a82dbb115da6d3f3 Mon Sep 17 00:00:00 2001 From: Matthew Richards Date: Mon, 20 Dec 2021 18:08:12 +0000 Subject: [PATCH] fix various edge cases for `SearchAPIQueryFilterFactory` #260 --- .../src/search_api/query_filter_factory.py | 23 +++++++++++++++---- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/datagateway_api/src/search_api/query_filter_factory.py b/datagateway_api/src/search_api/query_filter_factory.py index d27d199d..883a3fde 100644 --- a/datagateway_api/src/search_api/query_filter_factory.py +++ b/datagateway_api/src/search_api/query_filter_factory.py @@ -1,7 +1,7 @@ import logging from datagateway_api.src.common.base_query_filter_factory import QueryFilterFactory -from datagateway_api.src.common.exceptions import FilterError +from datagateway_api.src.common.exceptions import FilterError, SearchAPIError from datagateway_api.src.search_api.filters import ( SearchAPIIncludeFilter, SearchAPILimitFilter, @@ -9,6 +9,8 @@ SearchAPIWhereFilter, ) from datagateway_api.src.search_api.nested_where_filters import NestedWhereFilters +from datagateway_api.src.search_api.panosc_mappings import mappings +from datagateway_api.src.search_api.query import SearchAPIQuery log = logging.getLogger() @@ -43,7 +45,9 @@ def get_query_filter(request_filter, entity_name=None): ) elif filter_name == "include": query_filters.extend( - SearchAPIQueryFilterFactory.get_include_filter(filter_input), + SearchAPIQueryFilterFactory.get_include_filter( + filter_input, entity_name, + ), ) elif filter_name == "limit": query_filters.append(SearchAPILimitFilter(filter_input)) @@ -90,6 +94,7 @@ def get_where_filter(where_filter_input, entity_name): :return: The list of `NestedWhereFilters` and/ or `SearchAPIWhereFilter` objects created """ + where_filters = [] if ( list(where_filter_input.keys())[0] == "and" @@ -114,6 +119,7 @@ def get_where_filter(where_filter_input, entity_name): conditional_where_filters[:-1], conditional_where_filters[-1], boolean_operator, + SearchAPIQuery(entity_name), ) where_filters.append(nested) elif list(where_filter_input.keys())[0] == "text": @@ -164,7 +170,7 @@ def get_where_filter(where_filter_input, entity_name): return where_filters @staticmethod - def get_include_filter(include_filter_input): + def get_include_filter(include_filter_input, entity_name): """ Given an include filter input, return a list of `SearchAPIIncludeFilter` and any `NestedWhereFilters` and/ or `SearchAPIWhereFilter` objects if there is a scope @@ -196,9 +202,16 @@ def get_include_filter(include_filter_input): "Bad Include filter: Scope filter cannot have a skip filter", ) - # Scope filter can have WHERE and/ or INCLUDE filters + try: + # Get related field name in entity name format for recursive call + related_entity_name = mappings.get_panosc_related_entity_name( + entity_name, included_entity, + ) + except SearchAPIError as e: + # If the function call errors, it's a client issue at this point + raise FilterError(e) scope_query_filters = SearchAPIQueryFilterFactory.get_query_filter( - {"filter": related_model["scope"]}, included_entity, + {"filter": related_model["scope"]}, related_entity_name, ) for scope_query_filter in scope_query_filters: