diff --git a/datagateway_api/common/base_query_filter_factory.py b/datagateway_api/common/base_query_filter_factory.py new file mode 100644 index 00000000..89405f78 --- /dev/null +++ b/datagateway_api/common/base_query_filter_factory.py @@ -0,0 +1,14 @@ +from abc import abstractstaticmethod + + +class QueryFilterFactory(object): + @abstractstaticmethod + def get_query_filter(request_filter): + """ + Given a filter, return a matching Query filter object + + :param request_filter: The filter to create the QueryFilter for + :type request_filter: :class:`dict` + :return: The QueryFilter object created + """ + pass diff --git a/datagateway_api/common/datagateway_api/query_filter_factory.py b/datagateway_api/common/datagateway_api/query_filter_factory.py index 5d3830f8..2ad97b0e 100644 --- a/datagateway_api/common/datagateway_api/query_filter_factory.py +++ b/datagateway_api/common/datagateway_api/query_filter_factory.py @@ -1,5 +1,6 @@ import logging +from datagateway_api.common.base_query_filter_factory import QueryFilterFactory from datagateway_api.common.config import APIConfigOptions, config from datagateway_api.common.exceptions import ( ApiError, @@ -9,7 +10,7 @@ log = logging.getLogger() -class QueryFilterFactory(object): +class DataGatewayAPIQueryFilterFactory(QueryFilterFactory): @staticmethod def get_query_filter(request_filter): """ diff --git a/datagateway_api/common/helpers.py b/datagateway_api/common/helpers.py index 74ed1ac4..64a126b4 100644 --- a/datagateway_api/common/helpers.py +++ b/datagateway_api/common/helpers.py @@ -10,7 +10,7 @@ from datagateway_api.common.datagateway_api.database import models from datagateway_api.common.datagateway_api.query_filter_factory import ( - QueryFilterFactory, + DataGatewayAPIQueryFilterFactory, ) from datagateway_api.common.date_handler import DateHandler from datagateway_api.common.exceptions import ( @@ -104,7 +104,9 @@ def get_filters_from_query_string(): for arg in request.args: for value in request.args.getlist(arg): filters.extend( - QueryFilterFactory.get_query_filter({arg: json.loads(value)}), + DataGatewayAPIQueryFilterFactory.get_query_filter( + {arg: json.loads(value)} + ), ) return filters except Exception as e: diff --git a/datagateway_api/common/search_api/query_filter_factory.py b/datagateway_api/common/search_api/query_filter_factory.py new file mode 100644 index 00000000..2bb312c5 --- /dev/null +++ b/datagateway_api/common/search_api/query_filter_factory.py @@ -0,0 +1,35 @@ +import logging + +from datagateway_api.common.base_query_filter_factory import QueryFilterFactory +from datagateway_api.common.exceptions import FilterError + +log = logging.getLogger() + + +class SearchAPIQueryFilterFactory(QueryFilterFactory): + @staticmethod + def get_query_filter(request_filter): + query_param_name = list(request_filter)[0].lower() + query_filters = [] + + if query_param_name == "filter": + log.debug( + f"Filter: {request_filter['filter']}, Type: {type(request_filter['filter'])})" + ) + for filter_name, filter_input in request_filter["filter"].items(): + if filter_name == "where": + pass + elif filter_name == "include": + pass + elif filter_name == "limit": + pass + elif filter_name == "skip": + pass + else: + raise FilterError( + "No valid filter name given within filter query param" + ) + + return query_filters + else: + raise FilterError(f"Bad filter, please check input: {request_filter}") diff --git a/test/db/test_query_filter_factory.py b/test/db/test_query_filter_factory.py index 4171174b..fe0a6113 100644 --- a/test/db/test_query_filter_factory.py +++ b/test/db/test_query_filter_factory.py @@ -9,15 +9,17 @@ DatabaseWhereFilter, ) from datagateway_api.common.datagateway_api.query_filter_factory import ( - QueryFilterFactory, + DataGatewayAPIQueryFilterFactory, ) # TODO - Move outside of db/ -class TestQueryFilterFactory: +class TestDataGatewayAPIQueryFilterFactory: @pytest.mark.usefixtures("flask_test_app_db") def test_valid_distinct_filter(self): - test_filter = QueryFilterFactory.get_query_filter({"distinct": "TEST"}) + test_filter = DataGatewayAPIQueryFilterFactory.get_query_filter( + {"distinct": "TEST"} + ) assert isinstance(test_filter[0], DatabaseDistinctFieldFilter) assert len(test_filter) == 1 @@ -34,25 +36,27 @@ def test_valid_distinct_filter(self): ], ) def test_valid_include_filter(self, filter_input): - test_filter = QueryFilterFactory.get_query_filter(filter_input) + test_filter = DataGatewayAPIQueryFilterFactory.get_query_filter(filter_input) assert isinstance(test_filter[0], DatabaseIncludeFilter) assert len(test_filter) == 1 @pytest.mark.usefixtures("flask_test_app_db") def test_valid_limit_filter(self): - test_filter = QueryFilterFactory.get_query_filter({"limit": 10}) + test_filter = DataGatewayAPIQueryFilterFactory.get_query_filter({"limit": 10}) assert isinstance(test_filter[0], DatabaseLimitFilter) assert len(test_filter) == 1 @pytest.mark.usefixtures("flask_test_app_db") def test_valid_order_filter(self): - test_filter = QueryFilterFactory.get_query_filter({"order": "id DESC"}) + test_filter = DataGatewayAPIQueryFilterFactory.get_query_filter( + {"order": "id DESC"} + ) assert isinstance(test_filter[0], DatabaseOrderFilter) assert len(test_filter) == 1 @pytest.mark.usefixtures("flask_test_app_db") def test_valid_skip_filter(self): - test_filter = QueryFilterFactory.get_query_filter({"skip": 10}) + test_filter = DataGatewayAPIQueryFilterFactory.get_query_filter({"skip": 10}) assert isinstance(test_filter[0], DatabaseSkipFilter) assert len(test_filter) == 1 @@ -71,6 +75,6 @@ def test_valid_skip_filter(self): ], ) def test_valid_where_filter(self, filter_input): - test_filter = QueryFilterFactory.get_query_filter(filter_input) + test_filter = DataGatewayAPIQueryFilterFactory.get_query_filter(filter_input) assert isinstance(test_filter[0], DatabaseWhereFilter) assert len(test_filter) == 1 diff --git a/test/test_query_filter.py b/test/test_query_filter.py index 3304fa40..3ce98717 100644 --- a/test/test_query_filter.py +++ b/test/test_query_filter.py @@ -3,7 +3,7 @@ import pytest from datagateway_api.common.datagateway_api.query_filter_factory import ( - QueryFilterFactory, + DataGatewayAPIQueryFilterFactory, ) from datagateway_api.common.exceptions import ApiError from datagateway_api.common.filters import QueryFilter @@ -31,4 +31,4 @@ def test_invalid_query_filter_getter(self): return_value="invalid_backend", ): with pytest.raises(ApiError): - QueryFilterFactory.get_query_filter({"order": "id DESC"}) + DataGatewayAPIQueryFilterFactory.get_query_filter({"order": "id DESC"})