Skip to content

Commit

Permalink
refactor: make QueryFilterFactory return a list of filters #259
Browse files Browse the repository at this point in the history
- This is in preparation to add a `SearchAPIQueryFilterFactory` where a single query parameter will have multiple filters
- This commit also fixes the tests which impact this change
  • Loading branch information
MRichards99 committed Nov 4, 2021
1 parent 8ea6e2d commit 04d92f7
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 27 deletions.
12 changes: 6 additions & 6 deletions datagateway_api/common/datagateway_api/query_filter_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,18 +57,18 @@ def get_query_filter(request_filter):
field = list(request_filter[filter_name].keys())[0]
operation = list(request_filter[filter_name][field].keys())[0]
value = request_filter[filter_name][field][operation]
return WhereFilter(field, value, operation)
return [WhereFilter(field, value, operation)]
elif filter_name == "order":
field = request_filter["order"].split(" ")[0]
direction = request_filter["order"].split(" ")[1]
return OrderFilter(field, direction)
return [OrderFilter(field, direction)]
elif filter_name == "skip":
return SkipFilter(request_filter["skip"])
return [SkipFilter(request_filter["skip"])]
elif filter_name == "limit":
return LimitFilter(request_filter["limit"])
return [LimitFilter(request_filter["limit"])]
elif filter_name == "include":
return IncludeFilter(request_filter["include"])
return [IncludeFilter(request_filter["include"])]
elif filter_name == "distinct":
return DistinctFieldFilter(request_filter["distinct"])
return [DistinctFieldFilter(request_filter["distinct"])]
else:
raise FilterError(f" Bad filter: {request_filter}")
2 changes: 1 addition & 1 deletion datagateway_api/common/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def get_filters_from_query_string():
filters = []
for arg in request.args:
for value in request.args.getlist(arg):
filters.append(
filters.extend(
QueryFilterFactory.get_query_filter({arg: json.loads(value)}),
)
return filters
Expand Down
39 changes: 19 additions & 20 deletions test/db/test_query_filter_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@
)


# TODO - Move outside of db/
class TestQueryFilterFactory:
@pytest.mark.usefixtures("flask_test_app_db")
def test_valid_distinct_filter(self):
assert isinstance(
QueryFilterFactory.get_query_filter({"distinct": "TEST"}),
DatabaseDistinctFieldFilter,
)
test_filter = QueryFilterFactory.get_query_filter({"distinct": "TEST"})
assert isinstance(test_filter[0], DatabaseDistinctFieldFilter)
assert len(test_filter) == 1

@pytest.mark.usefixtures("flask_test_app_db")
@pytest.mark.parametrize(
Expand All @@ -34,28 +34,27 @@ def test_valid_distinct_filter(self):
],
)
def test_valid_include_filter(self, filter_input):
assert isinstance(
QueryFilterFactory.get_query_filter(filter_input), DatabaseIncludeFilter,
)
test_filter = QueryFilterFactory.get_query_filter(filter_input)
assert isinstance(test_filter[0], DatabaseIncludeFilter)
assert len(test_filter) == 1

@pytest.mark.usefixtures("flask_test_app_db")
def test_valid_limit_filter(self):
assert isinstance(
QueryFilterFactory.get_query_filter({"limit": 10}), DatabaseLimitFilter,
)
test_filter = QueryFilterFactory.get_query_filter({"limit": 10})
assert isinstance(test_filter[0], DatabaseLimitFilter)
assert len(test_filter) == 1

@pytest.mark.usefixtures("flask_test_app_db")
def test_valid_order_filter(self):
assert isinstance(
QueryFilterFactory.get_query_filter({"order": "id DESC"}),
DatabaseOrderFilter,
)
test_filter = QueryFilterFactory.get_query_filter({"order": "id DESC"})
assert isinstance(test_filter[0], DatabaseOrderFilter)
assert len(test_filter) == 1

@pytest.mark.usefixtures("flask_test_app_db")
def test_valid_skip_filter(self):
assert isinstance(
QueryFilterFactory.get_query_filter({"skip": 10}), DatabaseSkipFilter,
)
test_filter = QueryFilterFactory.get_query_filter({"skip": 10})
assert isinstance(test_filter[0], DatabaseSkipFilter)
assert len(test_filter) == 1

@pytest.mark.usefixtures("flask_test_app_db")
@pytest.mark.parametrize(
Expand All @@ -72,6 +71,6 @@ def test_valid_skip_filter(self):
],
)
def test_valid_where_filter(self, filter_input):
assert isinstance(
QueryFilterFactory.get_query_filter(filter_input), DatabaseWhereFilter,
)
test_filter = QueryFilterFactory.get_query_filter(filter_input)
assert isinstance(test_filter[0], DatabaseWhereFilter)
assert len(test_filter) == 1

0 comments on commit 04d92f7

Please sign in to comment.