diff --git a/datagateway_api/src/datagateway_api/icat/filters.py b/datagateway_api/src/datagateway_api/icat/filters.py index c114c9d5..bf69f3f0 100644 --- a/datagateway_api/src/datagateway_api/icat/filters.py +++ b/datagateway_api/src/datagateway_api/icat/filters.py @@ -210,14 +210,21 @@ def apply_filter(self, query): class PythonICATSkipFilter(SkipFilter): - def __init__(self, skip_value): + def __init__(self, skip_value, filter_use="datagateway_api"): super().__init__(skip_value) + self.filter_use = filter_use def apply_filter(self, query): - icat_properties = get_icat_properties( - Config.config.datagateway_api.icat_url, - Config.config.datagateway_api.icat_check_cert, - ) + if self.filter_use == "datagateway_api": + icat_properties = get_icat_properties( + Config.config.datagateway_api.icat_url, + Config.config.datagateway_api.icat_check_cert, + ) + else: + icat_properties = get_icat_properties( + Config.config.search_api.icat_url, + Config.config.search_api.icat_check_cert, + ) icat_set_limit(query, self.skip_value, icat_properties["maxEntities"]) diff --git a/datagateway_api/src/search_api/filters.py b/datagateway_api/src/search_api/filters.py index 00249a06..7e28ef16 100644 --- a/datagateway_api/src/search_api/filters.py +++ b/datagateway_api/src/search_api/filters.py @@ -19,7 +19,7 @@ def apply_filter(self, query): class SearchAPISkipFilter(PythonICATSkipFilter): def __init__(self, skip_value): - super().__init__(skip_value) + super().__init__(skip_value, filter_use="search_api") def apply_filter(self, query): return super().apply_filter(query) diff --git a/test/search_api/conftest.py b/test/search_api/conftest.py new file mode 100644 index 00000000..423d5e27 --- /dev/null +++ b/test/search_api/conftest.py @@ -0,0 +1,22 @@ +from icat.client import Client +from icat.query import Query +import pytest + +from datagateway_api.src.common.config import Config + + +@pytest.fixture(scope="package") +def icat_client(): + client = Client( + Config.config.search_api.icat_url, + checkCert=Config.config.search_api.icat_check_cert, + ) + client.login( + Config.config.test_mechanism, Config.config.test_user_credentials.dict(), + ) + return client + + +@pytest.fixture() +def icat_query(icat_client): + return Query(icat_client, "Investigation") diff --git a/test/search_api/test_limit_filter.py b/test/search_api/test_limit_filter.py new file mode 100644 index 00000000..f0245b72 --- /dev/null +++ b/test/search_api/test_limit_filter.py @@ -0,0 +1,28 @@ +import pytest + +from datagateway_api.src.common.exceptions import FilterError +from datagateway_api.src.search_api.filters import SearchAPILimitFilter + + +class TestSearchAPILimitFilter: + @pytest.mark.parametrize( + "limit_value", + [ + pytest.param(10, id="typical"), + pytest.param(0, id="low boundary"), + pytest.param(9999, id="high boundary"), + ], + ) + def test_valid_limit_value(self, icat_query, limit_value): + test_filter = SearchAPILimitFilter(limit_value) + test_filter.apply_filter(icat_query) + + assert icat_query.limit == (0, limit_value) + + @pytest.mark.parametrize( + "limit_value", + [pytest.param(-50, id="extreme invalid"), pytest.param(-1, id="boundary")], + ) + def test_invalid_limit_value(self, icat_query, limit_value): + with pytest.raises(FilterError): + SearchAPILimitFilter(limit_value) diff --git a/test/search_api/test_skip_filter.py b/test/search_api/test_skip_filter.py new file mode 100644 index 00000000..9d8c0093 --- /dev/null +++ b/test/search_api/test_skip_filter.py @@ -0,0 +1,31 @@ +import pytest + +from datagateway_api.src.common.config import Config +from datagateway_api.src.common.exceptions import FilterError +from datagateway_api.src.common.helpers import get_icat_properties +from datagateway_api.src.search_api.filters import SearchAPISkipFilter + + +class TestSearchAPISkipFilter: + @pytest.mark.parametrize( + "skip_value", [pytest.param(10, id="typical"), pytest.param(0, id="boundary")], + ) + def test_valid_skip_value(self, icat_query, skip_value): + test_filter = SearchAPISkipFilter(skip_value) + test_filter.apply_filter(icat_query) + + assert icat_query.limit == ( + skip_value, + get_icat_properties( + Config.config.search_api.icat_url, + Config.config.search_api.icat_check_cert, + )["maxEntities"], + ) + + @pytest.mark.parametrize( + "skip_value", + [pytest.param(-375, id="extreme invalid"), pytest.param(-1, id="boundary")], + ) + def test_invalid_skip_value(self, icat_query, skip_value): + with pytest.raises(FilterError): + SearchAPISkipFilter(skip_value)