Skip to content

Commit

Permalink
update get_rows_by_filter
Browse files Browse the repository at this point in the history
  • Loading branch information
keiranjprice101 committed Jul 29, 2019
1 parent d874231 commit 27a977c
Showing 1 changed file with 7 additions and 54 deletions.
61 changes: 7 additions & 54 deletions common/database_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def update_row_from_id(table, id, new_values):
row = get_row_by_id(table, id)
update_query = UpdateQuery(table, row, new_values)
update_query.execute_query()


def get_rows_by_filter(table, filters):
"""
Expand All @@ -256,68 +256,21 @@ def get_rows_by_filter(table, filters):
:param filters: The list of filters to be applied
:return: A list of the rows returned in dictionary form
"""
is_limited = False
session = get_icat_db_session()
base_query = session.query(table)
includes_relation = False
query = ReadQuery(table)
for query_filter in filters:
if len(query_filter) == 0:
pass
elif list(query_filter)[0].lower() == "where":
for key in query_filter:
where_part = query_filter[key]
for k in where_part:
column = getattr(table, k.upper())
base_query = base_query.filter(column.in_([where_part[k]]))
elif list(query_filter)[0].lower() == "order":
for key in query_filter:
field = query_filter[key].split(" ")[0]
direction = query_filter[key].split(" ")[1]
# Limit then order, or order then limit
if is_limited:
if direction.upper() == "ASC":
base_query = base_query.from_self().order_by(asc(getattr(table, field)))
elif direction.upper() == "DESC":
base_query = base_query.from_self().order_by(desc(getattr(table, field)))
else:
raise BadFilterError(f" Bad filter given, filter: {query_filter}")
else:
if direction.upper() == "ASC":
base_query = base_query.order_by(asc(getattr(table, field)))
elif direction.upper() == "DESC":
base_query = base_query.order_by(desc(getattr(table, field)))
else:
raise BadFilterError(f" Bad filter given, filter: {query_filter}")

elif list(query_filter)[0].lower() == "skip":
for key in query_filter:
skip = query_filter[key]
base_query = base_query.offset(skip)

elif list(query_filter)[0].lower() == "limit":
is_limited = True
for key in query_filter:
query_limit = query_filter[key]
base_query = base_query.limit(query_limit)
elif list(query_filter)[0].lower() == "include":
includes_relation = True

else:
raise BadFilterError(f"Invalid filters provided received {filters}")

results = base_query.all()
# check if include was provided, then add included results
if includes_relation:
log.info(" Closing DB session")
QueryFilterFactory.get_query_filter(query_filter).apply_filter(query)
results = query.get_all_results()
if query.include_related_entities:
for query_filter in filters:
if list(query_filter)[0] == "include":
if list(query_filter)[0].lower() == "include":
return list(map(lambda x: x.to_nested_dict(query_filter["include"]), results))

log.info(" Closing DB session")
session.close()
return list(map(lambda x: x.to_dict(), results))



def get_filtered_row_count(table, filters):
"""
returns the count of the rows that match a given filter in a given table
Expand Down

0 comments on commit 27a977c

Please sign in to comment.