diff --git a/common/database_helpers.py b/common/database_helpers.py index d9645fb9..204c78e8 100644 --- a/common/database_helpers.py +++ b/common/database_helpers.py @@ -212,7 +212,7 @@ class IncludeFilter(QueryFilter): precedence = 5 def __init__(self, included_filters): - self.included_filters = included_filters + self.included_filters = included_filters["include"] def apply_filter(self, query): query.include_related_entities = True @@ -259,6 +259,9 @@ def __init__(self): def add_filter(self, filter): self.filters.append(filter) + def add_filters(self, filters): + self.filters.extend(filters) + def sort_filters(self): """ Sorts the filters according to the order of operations @@ -346,11 +349,7 @@ def get_filtered_read_query_results(filter_handler, filters, query): :return: The results of the query as a list of dictionaries """ try: - for query_filter in filters: - if len(query_filter) == 0: - pass - else: - filter_handler.add_filter(QueryFilterFactory.get_query_filter(query_filter)) + filter_handler.add_filters(filters) filter_handler.apply_filters(query) results = query.get_all_results() if query.is_distinct_fields_query: @@ -372,8 +371,8 @@ def _get_results_with_include(filters, results): :return: A list of nested dictionaries representing the entity results """ for query_filter in filters: - if list(query_filter)[0].lower() == "include": - return [x.to_nested_dict(query_filter["include"]) for x in results] + if type(query_filter) is IncludeFilter: + return [x.to_nested_dict(query_filter.included_filters) for x in results] def _get_distinct_fields_as_dicts(results): @@ -424,11 +423,7 @@ def get_filtered_row_count(table, filters): log.info(f" getting count for {table.__tablename__}") count_query = CountQuery(table) filter_handler = FilterOrderHandler() - for query_filter in filters: - if len(query_filter) == 0: - pass - else: - filter_handler.add_filter(QueryFilterFactory.get_query_filter(query_filter)) + filter_handler.add_filters(filters) filter_handler.apply_filters(count_query) return count_query.get_count() @@ -502,11 +497,7 @@ def get_investigations_for_user_count(user_id, filters): """ count_query = UserInvestigationsCountQuery(user_id) filter_handler = FilterOrderHandler() - for query_filter in filters: - if len(query_filter) == 0: - pass - else: - filter_handler.add_filter(QueryFilterFactory.get_query_filter(query_filter)) + filter_handler.add_filters(filters) filter_handler.apply_filters(count_query) return count_query.get_count() diff --git a/common/helpers.py b/common/helpers.py index f22cd1e6..60ba6720 100644 --- a/common/helpers.py +++ b/common/helpers.py @@ -6,6 +6,7 @@ from flask_restful import reqparse from sqlalchemy.exc import IntegrityError +from common.database_helpers import QueryFilterFactory from common.exceptions import MissingRecordError, BadFilterError, AuthenticationError, BadRequestError from common.models.db_models import SESSION from common.session_manager import session_manager @@ -109,13 +110,12 @@ def is_valid_json(string): def get_filters_from_query_string(): """ - Gets a list of filters from the query_strings arg,value pairs. - :example: /datafiles?limit=10&where={"DATASET_ID":2} -> [{"limit":10}, {"where":{"DATASET_ID":10}}] + Gets a list of filters from the query_strings arg,value pairs, and returns a list of QueryFilter Objects :return: The list of filters """ log.info(" Getting filters from query string") filters = [] for arg in request.args: for value in request.args.getlist(arg): - filters.append({arg: json.loads(value)}) + filters.append(QueryFilterFactory.get_query_filter({arg: json.loads(value)})) return filters