Skip to content

Commit

Permalink
refactor: add in SearchAPIQueryFilterFactory #259
Browse files Browse the repository at this point in the history
- Added in a generic `QueryFilterFactory` object which inherits from the search API and DataGateway API versions
- Also fixed imports of the DataGateway API specific implementation
  • Loading branch information
MRichards99 committed Nov 4, 2021
1 parent 04d92f7 commit bbc9412
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 13 deletions.
14 changes: 14 additions & 0 deletions datagateway_api/common/base_query_filter_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from abc import abstractstaticmethod


class QueryFilterFactory(object):
@abstractstaticmethod
def get_query_filter(request_filter):
"""
Given a filter, return a matching Query filter object
:param request_filter: The filter to create the QueryFilter for
:type request_filter: :class:`dict`
:return: The QueryFilter object created
"""
pass
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging

from datagateway_api.common.base_query_filter_factory import QueryFilterFactory
from datagateway_api.common.config import APIConfigOptions, config
from datagateway_api.common.exceptions import (
ApiError,
Expand All @@ -9,7 +10,7 @@
log = logging.getLogger()


class QueryFilterFactory(object):
class DataGatewayAPIQueryFilterFactory(QueryFilterFactory):
@staticmethod
def get_query_filter(request_filter):
"""
Expand Down
6 changes: 4 additions & 2 deletions datagateway_api/common/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from datagateway_api.common.datagateway_api.database import models
from datagateway_api.common.datagateway_api.query_filter_factory import (
QueryFilterFactory,
DataGatewayAPIQueryFilterFactory,
)
from datagateway_api.common.date_handler import DateHandler
from datagateway_api.common.exceptions import (
Expand Down Expand Up @@ -104,7 +104,9 @@ def get_filters_from_query_string():
for arg in request.args:
for value in request.args.getlist(arg):
filters.extend(
QueryFilterFactory.get_query_filter({arg: json.loads(value)}),
DataGatewayAPIQueryFilterFactory.get_query_filter(
{arg: json.loads(value)}
),
)
return filters
except Exception as e:
Expand Down
35 changes: 35 additions & 0 deletions datagateway_api/common/search_api/query_filter_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import logging

from datagateway_api.common.base_query_filter_factory import QueryFilterFactory
from datagateway_api.common.exceptions import FilterError

log = logging.getLogger()


class SearchAPIQueryFilterFactory(QueryFilterFactory):
@staticmethod
def get_query_filter(request_filter):
query_param_name = list(request_filter)[0].lower()
query_filters = []

if query_param_name == "filter":
log.debug(
f"Filter: {request_filter['filter']}, Type: {type(request_filter['filter'])})"
)
for filter_name, filter_input in request_filter["filter"].items():
if filter_name == "where":
pass
elif filter_name == "include":
pass
elif filter_name == "limit":
pass
elif filter_name == "skip":
pass
else:
raise FilterError(
"No valid filter name given within filter query param"
)

return query_filters
else:
raise FilterError(f"Bad filter, please check input: {request_filter}")
20 changes: 12 additions & 8 deletions test/db/test_query_filter_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,17 @@
DatabaseWhereFilter,
)
from datagateway_api.common.datagateway_api.query_filter_factory import (
QueryFilterFactory,
DataGatewayAPIQueryFilterFactory,
)


# TODO - Move outside of db/
class TestQueryFilterFactory:
class TestDataGatewayAPIQueryFilterFactory:
@pytest.mark.usefixtures("flask_test_app_db")
def test_valid_distinct_filter(self):
test_filter = QueryFilterFactory.get_query_filter({"distinct": "TEST"})
test_filter = DataGatewayAPIQueryFilterFactory.get_query_filter(
{"distinct": "TEST"}
)
assert isinstance(test_filter[0], DatabaseDistinctFieldFilter)
assert len(test_filter) == 1

Expand All @@ -34,25 +36,27 @@ def test_valid_distinct_filter(self):
],
)
def test_valid_include_filter(self, filter_input):
test_filter = QueryFilterFactory.get_query_filter(filter_input)
test_filter = DataGatewayAPIQueryFilterFactory.get_query_filter(filter_input)
assert isinstance(test_filter[0], DatabaseIncludeFilter)
assert len(test_filter) == 1

@pytest.mark.usefixtures("flask_test_app_db")
def test_valid_limit_filter(self):
test_filter = QueryFilterFactory.get_query_filter({"limit": 10})
test_filter = DataGatewayAPIQueryFilterFactory.get_query_filter({"limit": 10})
assert isinstance(test_filter[0], DatabaseLimitFilter)
assert len(test_filter) == 1

@pytest.mark.usefixtures("flask_test_app_db")
def test_valid_order_filter(self):
test_filter = QueryFilterFactory.get_query_filter({"order": "id DESC"})
test_filter = DataGatewayAPIQueryFilterFactory.get_query_filter(
{"order": "id DESC"}
)
assert isinstance(test_filter[0], DatabaseOrderFilter)
assert len(test_filter) == 1

@pytest.mark.usefixtures("flask_test_app_db")
def test_valid_skip_filter(self):
test_filter = QueryFilterFactory.get_query_filter({"skip": 10})
test_filter = DataGatewayAPIQueryFilterFactory.get_query_filter({"skip": 10})
assert isinstance(test_filter[0], DatabaseSkipFilter)
assert len(test_filter) == 1

Expand All @@ -71,6 +75,6 @@ def test_valid_skip_filter(self):
],
)
def test_valid_where_filter(self, filter_input):
test_filter = QueryFilterFactory.get_query_filter(filter_input)
test_filter = DataGatewayAPIQueryFilterFactory.get_query_filter(filter_input)
assert isinstance(test_filter[0], DatabaseWhereFilter)
assert len(test_filter) == 1
4 changes: 2 additions & 2 deletions test/test_query_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest

from datagateway_api.common.datagateway_api.query_filter_factory import (
QueryFilterFactory,
DataGatewayAPIQueryFilterFactory,
)
from datagateway_api.common.exceptions import ApiError
from datagateway_api.common.filters import QueryFilter
Expand Down Expand Up @@ -31,4 +31,4 @@ def test_invalid_query_filter_getter(self):
return_value="invalid_backend",
):
with pytest.raises(ApiError):
QueryFilterFactory.get_query_filter({"order": "id DESC"})
DataGatewayAPIQueryFilterFactory.get_query_filter({"order": "id DESC"})

0 comments on commit bbc9412

Please sign in to comment.