Skip to content

Commit

Permalink
#2: Allow a single include
Browse files Browse the repository at this point in the history
  • Loading branch information
keiranjprice101 committed Jul 2, 2019
1 parent bf00f23 commit 21a6e5f
Showing 1 changed file with 25 additions and 2 deletions.
27 changes: 25 additions & 2 deletions common/database_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from sqlalchemy import create_engine, asc, desc
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm.collections import InstrumentedList

from common.constants import Constants
from common.exceptions import MissingRecordError, BadFilterError, BadRequestError
Expand Down Expand Up @@ -117,12 +118,13 @@ def get_rows_by_filter(table, filters):
"""
Given a list of filters supplied in json format, returns entities that match the filters from the given table
:param table: The table to checked
:param filters: The filters to be applied
:param filters: The list of filters to be applied
:return: A list of the rows returned in dictionary form
"""
is_limited = False
session = get_icat_db_session()
base_query = session.query(table)
includes_relation = False
for filter in filters:
if list(filter)[0].lower() == "where":
for key in filter:
Expand Down Expand Up @@ -161,13 +163,34 @@ def get_rows_by_filter(table, filters):
for key in filter:
limit = filter[key]
base_query = base_query.limit(limit)
elif list(filter)[0].lower() == "include":
includes_relation = True

else:
raise BadFilterError(f"Invalid filters provided recieved {filters}")

results = base_query.all()
if includes_relation:
included_relationships = []
for filter in filters:
if list(filter)[0] == "include":
included_relationships.append(filter["include"])
included_results = []
for row in results:
for relation in included_relationships:
# Here we check if the included result returns a list of children and if so iterate through them and
# add them to the results.
if isinstance(getattr(row, relation.upper()),InstrumentedList):
for i in getattr(row, relation.upper()):
included_results.append(i)
else:
included_results.append(getattr(row, relation.upper()))
results.extend(included_results)


log.info(" Closing DB session")
session.close()
return list(map(lambda x: x.to_dict(), base_query.all()))
return list(map(lambda x: x.to_dict(), results))


def get_filtered_row_count(table, filters):
Expand Down

0 comments on commit 21a6e5f

Please sign in to comment.