diff --git a/datagateway_api/src/search_api/query_filter_factory.py b/datagateway_api/src/search_api/query_filter_factory.py index 08cea9fa..fd669d63 100644 --- a/datagateway_api/src/search_api/query_filter_factory.py +++ b/datagateway_api/src/search_api/query_filter_factory.py @@ -19,7 +19,7 @@ class SearchAPIQueryFilterFactory(QueryFilterFactory): @staticmethod - def get_query_filter(request_filter, entity_name=None): + def get_query_filter(request_filter, entity_name=None, related_entity_name=None): """ Given a filter, return a list of matching query filter objects @@ -30,6 +30,9 @@ def get_query_filter(request_filter, entity_name=None): entity - this is needed for when there is a text operator inside a where filter :type entity_name: :class:`str` + :param related_entity_name: Entity name used when calling `get_where_filter()` + and `get_include_filter()` + :type related_entity_name: :class:`str` :return: The list of query filter objects created :raises FilterError: If the filter name is not recognised """ @@ -43,14 +46,14 @@ def get_query_filter(request_filter, entity_name=None): log.info("where JSON object found") query_filters.extend( SearchAPIQueryFilterFactory.get_where_filter( - filter_input, entity_name, + filter_input, entity_name, related_entity_name, ), ) elif filter_name == "include": log.info("include JSON object found") query_filters.extend( SearchAPIQueryFilterFactory.get_include_filter( - filter_input, entity_name, + filter_input, entity_name, related_entity_name, ), ) elif filter_name == "limit": @@ -80,7 +83,7 @@ def get_query_filter(request_filter, entity_name=None): return query_filters @staticmethod - def get_where_filter(where_filter_input, entity_name): + def get_where_filter(where_filter_input, entity_name, related_entity_name=None): """ Given a where filter input, return a list of `NestedWhereFilters` and/ or `SearchAPIWhereFilter` objects @@ -99,8 +102,12 @@ def get_where_filter(where_filter_input, entity_name): filter so that the value provided can be matched with the relevant text operator fields for the entity. :type entity_name: :class:`str` + :param related_entity_name: Entity name of a related entity, used for getting + text operator fields for said related entity + :type related_entity_name: :class:`str` :return: The list of `NestedWhereFilters` and/ or `SearchAPIWhereFilter` objects created + :raises SearchAPIError: If there are no text operator fields on the entity """ where_filters = [] @@ -134,10 +141,16 @@ def get_where_filter(where_filter_input, entity_name): elif list(where_filter_input.keys())[0] == "text": log.debug("Text operator found within JSON where object") try: - entity_class = getattr(search_api_models, entity_name) + # If there's a related entity name, fetch the text operator fields for + # that entity. This serves in use cases where there's a WHERE filter + # with a text operator on an included/related entity + entity_class_name = ( + related_entity_name if related_entity_name else entity_name + ) + entity_class = getattr(search_api_models, entity_class_name) except AttributeError as e: raise SearchAPIError( - f"No text operator fields have been defined for {entity_name}" + f"No text operator fields have been defined for {entity_class_name}" f", {e.args}", ) @@ -189,7 +202,7 @@ def get_where_filter(where_filter_input, entity_name): return where_filters @staticmethod - def get_include_filter(include_filter_input, entity_name): + def get_include_filter(include_filter_input, entity_name, related_entity_name=None): """ Given an include filter input, return a list of `SearchAPIIncludeFilter` and any `NestedWhereFilters` and/ or `SearchAPIWhereFilter` objects if there is a scope @@ -202,6 +215,13 @@ def get_include_filter(include_filter_input, entity_name): `SearchAPIIncludeFilter` and any `NestedWhereFilters` and/or `SearchAPIWhereFilter` objects :type include_filter_input: :class:`list` + :param entity_name: Entity name of the endpoint or the name of the included + entity - this is needed for when there is a text operator inside a where + filter so that the value provided can be matched with the relevant text + operator fields for the entity. + :type entity_name: :class:`str` + :param related_entity_name: Entity name of a related entity + :type related_entity_name: :class:`str` :return: The list of `SearchAPIIncludeFilter` and any `NestedWhereFilters` and/ or `SearchAPIWhereFilter` objects created :raises FilterError: If scope filter has a limit or skip filter @@ -230,15 +250,20 @@ def get_include_filter(include_filter_input, entity_name): ) try: + entity_class_name = ( + related_entity_name if related_entity_name else entity_name + ) # Get related field name in entity name format for recursive call related_entity_name = mappings.get_panosc_related_entity_name( - entity_name, included_entity, + entity_class_name, included_entity, ) except SearchAPIError as e: # If the function call errors, it's a client issue at this point raise FilterError(e) scope_query_filters = SearchAPIQueryFilterFactory.get_query_filter( - {"filter": related_model["scope"]}, related_entity_name, + {"filter": related_model["scope"]}, + entity_name, + related_entity_name, ) for scope_query_filter in scope_query_filters: @@ -259,6 +284,9 @@ def get_include_filter(include_filter_input, entity_name): ] = f"{included_entity}.{included_filter}" query_filters.extend(scope_query_filters) + # Flush related entity name so a bug doesn't occur with multiple related + # models + related_entity_name = None if not nested_include: query_filters.append(