diff --git a/datagateway_api/src/common/base_query_filter_factory.py b/datagateway_api/src/common/base_query_filter_factory.py index 89405f78..d9a9c4a4 100644 --- a/datagateway_api/src/common/base_query_filter_factory.py +++ b/datagateway_api/src/common/base_query_filter_factory.py @@ -3,12 +3,15 @@ class QueryFilterFactory(object): @abstractstaticmethod - def get_query_filter(request_filter): + def get_query_filter(request_filter, entity_name=None): """ Given a filter, return a matching Query filter object :param request_filter: The filter to create the QueryFilter for :type request_filter: :class:`dict` + :param entity_name: Entity name of the endpoint, optional (only used for search + API, not DataGateway API) + :type entity_name: :class:`str` :return: The QueryFilter object created """ pass diff --git a/datagateway_api/src/common/helpers.py b/datagateway_api/src/common/helpers.py index f6b3c24c..79d8c7ec 100644 --- a/datagateway_api/src/common/helpers.py +++ b/datagateway_api/src/common/helpers.py @@ -87,7 +87,7 @@ def is_valid_json(string): return True -def get_filters_from_query_string(api_type): +def get_filters_from_query_string(api_type, entity_name=None): """ Gets a list of filters from the query_strings arg,value pairs, and returns a list of QueryFilter Objects @@ -95,6 +95,9 @@ def get_filters_from_query_string(api_type): :param api_type: Type of API this function is being used for i.e. DataGateway API or Search API :type api_type: :class:`str` + :param entity_name: Entity name of the endpoint, optional (only used for search + API, not DataGateway API) + :type entity_name: :class:`str` :raises ApiError: If `api_type` isn't a valid value :return: The list of filters """ @@ -117,7 +120,9 @@ def get_filters_from_query_string(api_type): for arg in request.args: for value in request.args.getlist(arg): filters.extend( - QueryFilterFactory.get_query_filter({arg: json.loads(value)}), + QueryFilterFactory.get_query_filter( + {arg: json.loads(value)}, entity_name, + ), ) return filters except Exception as e: diff --git a/datagateway_api/src/datagateway_api/query_filter_factory.py b/datagateway_api/src/datagateway_api/query_filter_factory.py index 7f5f838f..e50a3eb7 100644 --- a/datagateway_api/src/datagateway_api/query_filter_factory.py +++ b/datagateway_api/src/datagateway_api/query_filter_factory.py @@ -12,7 +12,7 @@ class DataGatewayAPIQueryFilterFactory(QueryFilterFactory): @staticmethod - def get_query_filter(request_filter): + def get_query_filter(request_filter, entity_name=None): """ Given a filter, return a matching Query filter object @@ -23,6 +23,11 @@ def get_query_filter(request_filter): :param request_filter: The filter to create the QueryFilter for :type request_filter: :class:`dict` + :param entity_name: Not utilised in DataGateway API implementation of this + static function, used in the search API. It is part of the method signature + as the same function call (called in `get_filters_from_query_string()`) is + used for both implementations + :type entity_name: :class:`str` :return: The QueryFilter object created :raises ApiError: If the backend type contains an invalid value :raises FilterError: If the filter name is not recognised diff --git a/datagateway_api/src/resources/search_api_endpoints.py b/datagateway_api/src/resources/search_api_endpoints.py index 0290b451..e9fd1353 100644 --- a/datagateway_api/src/resources/search_api_endpoints.py +++ b/datagateway_api/src/resources/search_api_endpoints.py @@ -17,7 +17,7 @@ def get_search_endpoint(name): class Endpoint(Resource): def get(self): - filters = get_filters_from_query_string("search_api") + filters = get_filters_from_query_string("search_api", name) log.debug("Filters: %s", filters) """ TODO - Need to return similar to @@ -46,7 +46,7 @@ def get_single_endpoint(name): class EndpointWithID(Resource): def get(self, pid): - filters = get_filters_from_query_string("search_api") + filters = get_filters_from_query_string("search_api", name) log.debug("Filters: %s", filters) # TODO - Add return pass @@ -65,7 +65,7 @@ def get_number_count_endpoint(name): class CountEndpoint(Resource): def get(self): # Only WHERE included on count endpoints - filters = get_filters_from_query_string("search_api") + filters = get_filters_from_query_string("search_api", name) log.debug("Filters: %s", filters) # TODO - Add return pass @@ -83,7 +83,7 @@ def get_files_endpoint(name): class FilesEndpoint(Resource): def get(self, pid): - filters = get_filters_from_query_string("search_api") + filters = get_filters_from_query_string("search_api", name) log.debug("Filters: %s", filters) # TODO - Add return pass @@ -102,7 +102,7 @@ def get_number_count_files_endpoint(name): class CountFilesEndpoint(Resource): def get(self, pid): # Only WHERE included on count endpoints - filters = get_filters_from_query_string("search_api") + filters = get_filters_from_query_string("search_api", name) log.debug("Filters: %s", filters) # TODO - Add return pass