Skip to content

Commit

Permalink
Merge pull request #16 from ral-facilities/2_add_include_and_order_fi…
Browse files Browse the repository at this point in the history
…lters

Add filtering
  • Loading branch information
keiranjprice101 authored Jul 23, 2019
2 parents 457630e + ed6ebd8 commit 27cfed0
Show file tree
Hide file tree
Showing 5 changed files with 156 additions and 95 deletions.
87 changes: 63 additions & 24 deletions common/database_helpers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import datetime
import logging

from sqlalchemy import create_engine
from sqlalchemy import create_engine, asc, desc
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm.collections import InstrumentedList

from common.constants import Constants
from common.exceptions import MissingRecordError, BadFilterError, BadRequestError
Expand Down Expand Up @@ -71,7 +72,7 @@ def get_row_by_id(table, id):
session.close()
return result
session.close()
raise MissingRecordError()
raise MissingRecordError(f" Could not find record in {table.__tablename__} with ID: {id}")


def delete_row_by_id(table, id):
Expand All @@ -90,7 +91,7 @@ def delete_row_by_id(table, id):
session.close()
return
session.close()
raise MissingRecordError()
raise MissingRecordError(f" Could not find record in {table.__tablename__} with ID: {id}")


def update_row_from_id(table, id, new_values):
Expand All @@ -110,39 +111,77 @@ def update_row_from_id(table, id, new_values):
session.close()
return
session.close()
raise MissingRecordError()
raise MissingRecordError(f" Could not find record in {table.__tablename__} with ID: {id}")


def get_rows_by_filter(table, filters):
"""
Given a list of filters supplied in json format, returns entities that match the filters from the given table
:param table: The table to checked
:param filters: The list of filters to be applied
:return: A list of the rows returned in dictionary form
"""
is_limited = False
session = get_icat_db_session()
base_query = session.query(table)
for filter in filters:
if len(filter) == 0:
includes_relation = False
for query_filter in filters:
if len(query_filter) == 0:
pass
elif list(filter)[0].lower() == "where":
for key in filter:
where_part = filter[key]
elif list(query_filter)[0].lower() == "where":
for key in query_filter:
where_part = query_filter[key]
for k in where_part:
column = getattr(table, k.upper())
base_query = base_query.filter(column.in_([where_part[k]]))

elif list(filter)[0].lower() == "order":
base_query.order() # do something probably not .order
elif list(filter)[0].lower() == "skip":
for key in filter:
skip = filter[key]
elif list(query_filter)[0].lower() == "order":
for key in query_filter:
field = query_filter[key].split(" ")[0]
direction = query_filter[key].split(" ")[1]
# Limit then order, or order then limit
if is_limited:
if direction.upper() == "ASC":
base_query = base_query.from_self().order_by(asc(getattr(table, field)))
elif direction.upper() == "DESC":
base_query = base_query.from_self().order_by(desc(getattr(table, field)))
else:
raise BadFilterError(f" Bad filter given, filter: {query_filter}")
else:
if direction.upper() == "ASC":
base_query = base_query.order_by(asc(getattr(table, field)))
elif direction.upper() == "DESC":
base_query = base_query.order_by(desc(getattr(table, field)))
else:
raise BadFilterError(f" Bad filter given, filter: {query_filter}")

elif list(query_filter)[0].lower() == "skip":
for key in query_filter:
skip = query_filter[key]
base_query = base_query.offset(skip)
elif list(filter)[0].lower() == "include":
base_query.include() # do something probably not .include
elif list(filter)[0].lower() == "limit":
for key in filter:
limit = filter[key]
base_query = base_query.limit(limit)

elif list(query_filter)[0].lower() == "limit":
is_limited = True
for key in query_filter:
query_limit = query_filter[key]
base_query = base_query.limit(query_limit)
elif list(query_filter)[0].lower() == "include":
includes_relation = True

else:
raise BadFilterError()
raise BadFilterError(f"Invalid filters provided received {filters}")

results = base_query.all()
# check if include was provided, then add included results
if includes_relation:
log.info(" Closing DB session")
for query_filter in filters:
if list(query_filter)[0] == "include":
return list(map(lambda x: x.to_nested_dict(query_filter["include"]), results))


log.info(" Closing DB session")
session.close()
return list(map(lambda x: x.to_dict(), base_query.all()))
return list(map(lambda x: x.to_dict(), results))


def get_filtered_row_count(table, filters):
Expand Down Expand Up @@ -190,6 +229,6 @@ def patch_entities(table, json_list):
result = get_row_by_id(table, entity[key])
results.append(result)
if len(results) == 0:
raise BadRequestError()
raise BadRequestError(f" Bad request made, request: {json_list}")

return results
16 changes: 5 additions & 11 deletions common/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,20 @@


class ApiError(Exception):
def __init__(self):
log.info(" ApiError(): An error has been raised.")
pass


class MissingRecordError(ApiError):
def __init__(self):
log.info(" MissingRecordError(): Record not found, DB session Closed")

pass


class BadFilterError(ApiError):
def __init__(self):
log.info(" BadFilterError(): Bad filter supplied")
pass


class AuthenticationError(ApiError):
def __init__(self):
log.info(" AuthenticationError(): Error authenticating consumer")
pass


class BadRequestError(ApiError):
def __init__(self):
log.info(" BadRequestError(): Bad request by Consumer")
pass
4 changes: 2 additions & 2 deletions common/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def wrapper_requires_session(*args, **kwargs):
log.info(" Consumer authenticated")
return method(*args, **kwargs)
else:
log.info(" Closing DB session")
log.info(" Could not authenticate consumer, closing DB session")
session.close()
return "Forbidden", 403
except AuthenticationError:
Expand Down Expand Up @@ -87,7 +87,7 @@ def get_session_id_from_auth_header():
if auth_header == "":
return ""
if len(auth_header) != 2 or auth_header[0] != "Bearer":
raise AuthenticationError()
raise AuthenticationError(f" Could not authenticate consumer with auth header {auth_header}")
return auth_header[1]


Expand Down
Loading

0 comments on commit 27cfed0

Please sign in to comment.