-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
b25e8dc
commit b60fb4f
Showing
7 changed files
with
1,100 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
class Constants: | ||
DATABASE_URL = "mysql+pymysql://root:rootpw@localhost:13306/icatdb" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,176 @@ | ||
import json | ||
import datetime | ||
|
||
from sqlalchemy import create_engine | ||
from sqlalchemy.orm import sessionmaker | ||
|
||
from common.constants import Constants | ||
from common.exceptions import MissingRecordError, BadFilterError | ||
|
||
|
||
def get_record_by_id(table, id): | ||
""" | ||
Gets a row from the dummy data credential database | ||
:param table: the table class mapping | ||
:param id: the id to find | ||
:return: the row from the table | ||
""" | ||
session = get_db_session() | ||
result = session.query(table).filter(table.ID == id).first() | ||
if result is not None: | ||
session.close() | ||
return result | ||
session.close() | ||
raise MissingRecordError() | ||
|
||
|
||
def get_db_session(): | ||
""" | ||
Gets a session in the dummy data database, currently used for credentials until Authentication is understood | ||
:return: the dummy data DB session | ||
""" | ||
engine = create_engine("mysql+pymysql://root:root@localhost:3306/icatdummy") | ||
Session = sessionmaker(bind=engine) | ||
session = Session() | ||
return session | ||
|
||
|
||
def get_icat_db_session(): | ||
""" | ||
Gets a session and connects with the ICAT database | ||
:return: the session object | ||
""" | ||
engine = create_engine(Constants.DATABASE_URL) | ||
Session = sessionmaker(bind=engine) | ||
session = Session() | ||
return session | ||
|
||
|
||
def insert_row_into_table(row): | ||
""" | ||
Insert the given row into its table | ||
:param row: The row to be inserted | ||
""" | ||
session = get_icat_db_session() | ||
session.add(row) | ||
session.commit() | ||
session.close() | ||
|
||
|
||
def create_row_from_json(table, json): | ||
""" | ||
Given a json dictionary create a row in the table from it | ||
:param table: the table for the row to be inserted into | ||
:param json: the dictionary containing the values | ||
:return: nothing atm | ||
""" | ||
session = get_icat_db_session() | ||
record = table() | ||
record.update_from_dict(json) | ||
record.CREATE_TIME = datetime.datetime.now() # These should probably change | ||
record.CREATE_ID = "user" | ||
record.MOD_TIME = datetime.datetime.now() | ||
record.MOD_ID = "user" | ||
session.add(record) | ||
session.commit() | ||
session.close() | ||
|
||
|
||
def get_row_by_id(table, id): | ||
""" | ||
Gets the row matching the given ID from the given table, raises MissingRecordError if it can not be found | ||
:param table: the table to be searched | ||
:param id: the id of the record to find | ||
:return: the record retrieved | ||
""" | ||
session = get_icat_db_session() | ||
result = session.query(table).filter(table.ID == id).first() | ||
if result is not None: | ||
session.close() | ||
return result | ||
session.close() | ||
raise MissingRecordError() | ||
|
||
|
||
def delete_row_by_id(table, id): | ||
""" | ||
Deletes the row matching the given ID from the given table, raises MissingRecordError if it can not be found | ||
:param table: the table to be searched | ||
:param id: the id of the record to delete | ||
""" | ||
session = get_icat_db_session() | ||
result = get_row_by_id(table, id) | ||
if result is not None: | ||
session.delete(result) | ||
session.commit() | ||
session.close() | ||
return | ||
session.close() | ||
raise MissingRecordError() | ||
|
||
|
||
def update_row_from_id(table, id, new_values): | ||
""" | ||
Updates a record in a table | ||
:param table: The table the record is in | ||
:param id: The id of the record | ||
:param new_values: A JSON string containing what columns are to be updated | ||
""" | ||
session = get_icat_db_session() | ||
record = session.query(table).filter(table.ID == id).first() | ||
if record is not None: | ||
record.update_from_dict(new_values) | ||
session.commit() | ||
session.close() | ||
return | ||
session.close() | ||
raise MissingRecordError() | ||
|
||
|
||
def get_rows_by_filter(table, filters): | ||
session = get_icat_db_session() | ||
base_query = session.query(table) | ||
for filter in filters: | ||
if list(filter)[0].lower() == "where": | ||
for key in filter: | ||
where_part = filter[key] | ||
for k in where_part: | ||
column = getattr(table, k.upper()) | ||
base_query = base_query.filter(column.in_([where_part[k]])) | ||
|
||
elif list(filter)[0].lower() == "order": | ||
base_query.order() # do something probably not .order | ||
elif list(filter)[0].lower() == "skip": | ||
for key in filter: | ||
skip = filter[key] | ||
base_query = base_query.offset(skip) | ||
elif list(filter)[0].lower() == "include": | ||
base_query.include() # do something probably not .include | ||
elif list(filter)[0].lower() == "limit": | ||
for key in filter: | ||
limit = filter[key] | ||
base_query = base_query.limit(limit) | ||
else: | ||
raise BadFilterError() | ||
session.close() | ||
return list(map(lambda x: x.to_dict(), base_query.all())) | ||
|
||
|
||
def get_filtered_row_count(table, filters): | ||
""" | ||
returns the count of the rows that match a given filter in a given table | ||
:param table: the table to be checked | ||
:param filters: the filters to be applied to the query | ||
:return: int: the count of the rows | ||
""" | ||
return len(get_rows_by_filter(table, filters)) | ||
|
||
|
||
def get_first_filtered_row(table, filters): | ||
""" | ||
returns the first row that matches a given filter, in a given table | ||
:param table: the table to be checked | ||
:param filters: the filter to be applied to the query | ||
:return: the first row matching the filter | ||
""" | ||
return get_rows_by_filter(table, filters)[0] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
class ApiError(Exception): | ||
pass | ||
|
||
|
||
class MissingRecordError(ApiError): | ||
pass | ||
|
||
|
||
class BadFilterError(ApiError): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
import json | ||
from functools import wraps | ||
|
||
from flask import request | ||
from flask_restful import reqparse | ||
from sqlalchemy.exc import IntegrityError | ||
|
||
from common.database_helpers import get_icat_db_session | ||
from common.exceptions import MissingRecordError, BadFilterError | ||
from common.models.db_models import SESSION | ||
|
||
|
||
def requires_session_id(method): | ||
""" | ||
Decorator for endpoint resources that makes sure a valid session_id is provided in requests to that endpoint | ||
:param method: The method for the endpoint | ||
:returns a 403, "Forbidden" if a valid session_id is not provided with the request | ||
""" | ||
|
||
@wraps(method) | ||
def wrapper_requires_session(*args, **kwargs): | ||
session = get_icat_db_session() | ||
query = session.query(SESSION).filter( | ||
SESSION.ID == get_session_id_from_auth_header()).first() | ||
if query is not None: | ||
return method(*args, **kwargs) | ||
else: | ||
return "Forbidden", 403 | ||
|
||
return wrapper_requires_session | ||
|
||
|
||
def queries_records(method): | ||
""" | ||
Decorator for endpoint resources that search for a record in a table | ||
:param method: The method for the endpoint | ||
:return: Will return a 404, "No such record" if a MissingRecordError is caught | ||
""" | ||
|
||
@wraps(method) | ||
def wrapper_gets_records(*args, **kwargs): | ||
try: | ||
return method(*args, **kwargs) | ||
except MissingRecordError: | ||
return "No such record in table", 404 | ||
except BadFilterError: | ||
return "Invalid filter requested", 400 | ||
except ValueError: | ||
return "Bad request", 400 | ||
except TypeError: | ||
return "Bad request", 400 | ||
except IntegrityError as e: | ||
return "Bad request", 400 | ||
|
||
return wrapper_gets_records | ||
|
||
|
||
def get_session_id_from_auth_header(): | ||
""" | ||
Gets the sessionID from the Authorization header of a request | ||
:return: String: SessionID | ||
""" | ||
parser = reqparse.RequestParser() | ||
parser.add_argument("Authorization", location="headers") | ||
args = parser.parse_args() | ||
if args["Authorization"] is not None: | ||
return args["Authorization"] | ||
return "" | ||
|
||
|
||
def is_valid_json(string): | ||
""" | ||
Determines if a string is valid JSON | ||
:param string: The string to be tested | ||
:return: boolean representing if the string is valid JSON | ||
""" | ||
try: | ||
json_object = json.loads(string) | ||
except ValueError: | ||
return False | ||
return True | ||
|
||
|
||
def get_filters_from_query_string(): | ||
filters = request.args.getlist("filter") | ||
return list(map(lambda x: json.loads(x), filters)) |
Empty file.
Oops, something went wrong.