diff --git a/common/database_helpers.py b/common/database_helpers.py index 085f6d8b..04316bec 100644 --- a/common/database_helpers.py +++ b/common/database_helpers.py @@ -247,7 +247,7 @@ def update_row_from_id(table, id, new_values): row = get_row_by_id(table, id) update_query = UpdateQuery(table, row, new_values) update_query.execute_query() - + def get_rows_by_filter(table, filters): """ @@ -256,68 +256,21 @@ def get_rows_by_filter(table, filters): :param filters: The list of filters to be applied :return: A list of the rows returned in dictionary form """ - is_limited = False - session = get_icat_db_session() - base_query = session.query(table) - includes_relation = False + query = ReadQuery(table) for query_filter in filters: if len(query_filter) == 0: pass - elif list(query_filter)[0].lower() == "where": - for key in query_filter: - where_part = query_filter[key] - for k in where_part: - column = getattr(table, k.upper()) - base_query = base_query.filter(column.in_([where_part[k]])) - elif list(query_filter)[0].lower() == "order": - for key in query_filter: - field = query_filter[key].split(" ")[0] - direction = query_filter[key].split(" ")[1] - # Limit then order, or order then limit - if is_limited: - if direction.upper() == "ASC": - base_query = base_query.from_self().order_by(asc(getattr(table, field))) - elif direction.upper() == "DESC": - base_query = base_query.from_self().order_by(desc(getattr(table, field))) - else: - raise BadFilterError(f" Bad filter given, filter: {query_filter}") - else: - if direction.upper() == "ASC": - base_query = base_query.order_by(asc(getattr(table, field))) - elif direction.upper() == "DESC": - base_query = base_query.order_by(desc(getattr(table, field))) - else: - raise BadFilterError(f" Bad filter given, filter: {query_filter}") - - elif list(query_filter)[0].lower() == "skip": - for key in query_filter: - skip = query_filter[key] - base_query = base_query.offset(skip) - - elif list(query_filter)[0].lower() == "limit": - is_limited = True - for key in query_filter: - query_limit = query_filter[key] - base_query = base_query.limit(query_limit) - elif list(query_filter)[0].lower() == "include": - includes_relation = True - else: - raise BadFilterError(f"Invalid filters provided received {filters}") - - results = base_query.all() - # check if include was provided, then add included results - if includes_relation: - log.info(" Closing DB session") + QueryFilterFactory.get_query_filter(query_filter).apply_filter(query) + results = query.get_all_results() + if query.include_related_entities: for query_filter in filters: - if list(query_filter)[0] == "include": + if list(query_filter)[0].lower() == "include": return list(map(lambda x: x.to_nested_dict(query_filter["include"]), results)) - - log.info(" Closing DB session") - session.close() return list(map(lambda x: x.to_dict(), results)) + def get_filtered_row_count(table, filters): """ returns the count of the rows that match a given filter in a given table