Skip to content

Commit

Permalink
pass in entity name to get_query_filter() to be used by text operator
Browse files Browse the repository at this point in the history
  • Loading branch information
MRichards99 committed Nov 29, 2021
1 parent c39795c commit 1d24021
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 9 deletions.
5 changes: 4 additions & 1 deletion datagateway_api/src/common/base_query_filter_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@

class QueryFilterFactory(object):
@abstractstaticmethod
def get_query_filter(request_filter):
def get_query_filter(request_filter, entity_name=None):
"""
Given a filter, return a matching Query filter object
:param request_filter: The filter to create the QueryFilter for
:type request_filter: :class:`dict`
:param entity_name: Entity name of the endpoint, optional (only used for search
API, not DataGateway API)
:type entity_name: :class:`str`
:return: The QueryFilter object created
"""
pass
9 changes: 7 additions & 2 deletions datagateway_api/src/common/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,17 @@ def is_valid_json(string):
return True


def get_filters_from_query_string(api_type):
def get_filters_from_query_string(api_type, entity_name=None):
"""
Gets a list of filters from the query_strings arg,value pairs, and returns a list of
QueryFilter Objects
:param api_type: Type of API this function is being used for i.e. DataGateway API or
Search API
:type api_type: :class:`str`
:param entity_name: Entity name of the endpoint, optional (only used for search
API, not DataGateway API)
:type entity_name: :class:`str`
:raises ApiError: If `api_type` isn't a valid value
:return: The list of filters
"""
Expand All @@ -117,7 +120,9 @@ def get_filters_from_query_string(api_type):
for arg in request.args:
for value in request.args.getlist(arg):
filters.extend(
QueryFilterFactory.get_query_filter({arg: json.loads(value)}),
QueryFilterFactory.get_query_filter(
{arg: json.loads(value)}, entity_name,
),
)
return filters
except Exception as e:
Expand Down
7 changes: 6 additions & 1 deletion datagateway_api/src/datagateway_api/query_filter_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

class DataGatewayAPIQueryFilterFactory(QueryFilterFactory):
@staticmethod
def get_query_filter(request_filter):
def get_query_filter(request_filter, entity_name=None):
"""
Given a filter, return a matching Query filter object
Expand All @@ -23,6 +23,11 @@ def get_query_filter(request_filter):
:param request_filter: The filter to create the QueryFilter for
:type request_filter: :class:`dict`
:param entity_name: Not utilised in DataGateway API implementation of this
static function, used in the search API. It is part of the method signature
as the same function call (called in `get_filters_from_query_string()`) is
used for both implementations
:type entity_name: :class:`str`
:return: The QueryFilter object created
:raises ApiError: If the backend type contains an invalid value
:raises FilterError: If the filter name is not recognised
Expand Down
10 changes: 5 additions & 5 deletions datagateway_api/src/resources/search_api_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def get_search_endpoint(name):

class Endpoint(Resource):
def get(self):
filters = get_filters_from_query_string("search_api")
filters = get_filters_from_query_string("search_api", name)
log.debug("Filters: %s", filters)
"""
TODO - Need to return similar to
Expand Down Expand Up @@ -46,7 +46,7 @@ def get_single_endpoint(name):

class EndpointWithID(Resource):
def get(self, pid):
filters = get_filters_from_query_string("search_api")
filters = get_filters_from_query_string("search_api", name)
log.debug("Filters: %s", filters)
# TODO - Add return
pass
Expand All @@ -65,7 +65,7 @@ def get_number_count_endpoint(name):
class CountEndpoint(Resource):
def get(self):
# Only WHERE included on count endpoints
filters = get_filters_from_query_string("search_api")
filters = get_filters_from_query_string("search_api", name)
log.debug("Filters: %s", filters)
# TODO - Add return
pass
Expand All @@ -83,7 +83,7 @@ def get_files_endpoint(name):

class FilesEndpoint(Resource):
def get(self, pid):
filters = get_filters_from_query_string("search_api")
filters = get_filters_from_query_string("search_api", name)
log.debug("Filters: %s", filters)
# TODO - Add return
pass
Expand All @@ -102,7 +102,7 @@ def get_number_count_files_endpoint(name):
class CountFilesEndpoint(Resource):
def get(self, pid):
# Only WHERE included on count endpoints
filters = get_filters_from_query_string("search_api")
filters = get_filters_from_query_string("search_api", name)
log.debug("Filters: %s", filters)
# TODO - Add return
pass
Expand Down

0 comments on commit 1d24021

Please sign in to comment.