Skip to content

Commit

Permalink
Add common package
Browse files Browse the repository at this point in the history
  • Loading branch information
keiranjprice101 committed Jun 11, 2019
1 parent b25e8dc commit b60fb4f
Show file tree
Hide file tree
Showing 7 changed files with 1,100 additions and 0 deletions.
Empty file added common/__init__.py
Empty file.
2 changes: 2 additions & 0 deletions common/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
class Constants:
DATABASE_URL = "mysql+pymysql://root:rootpw@localhost:13306/icatdb"
176 changes: 176 additions & 0 deletions common/database_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
import json
import datetime

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from common.constants import Constants
from common.exceptions import MissingRecordError, BadFilterError


def get_record_by_id(table, id):
"""
Gets a row from the dummy data credential database
:param table: the table class mapping
:param id: the id to find
:return: the row from the table
"""
session = get_db_session()
result = session.query(table).filter(table.ID == id).first()
if result is not None:
session.close()
return result
session.close()
raise MissingRecordError()


def get_db_session():
"""
Gets a session in the dummy data database, currently used for credentials until Authentication is understood
:return: the dummy data DB session
"""
engine = create_engine("mysql+pymysql://root:root@localhost:3306/icatdummy")
Session = sessionmaker(bind=engine)
session = Session()
return session


def get_icat_db_session():
"""
Gets a session and connects with the ICAT database
:return: the session object
"""
engine = create_engine(Constants.DATABASE_URL)
Session = sessionmaker(bind=engine)
session = Session()
return session


def insert_row_into_table(row):
"""
Insert the given row into its table
:param row: The row to be inserted
"""
session = get_icat_db_session()
session.add(row)
session.commit()
session.close()


def create_row_from_json(table, json):
"""
Given a json dictionary create a row in the table from it
:param table: the table for the row to be inserted into
:param json: the dictionary containing the values
:return: nothing atm
"""
session = get_icat_db_session()
record = table()
record.update_from_dict(json)
record.CREATE_TIME = datetime.datetime.now() # These should probably change
record.CREATE_ID = "user"
record.MOD_TIME = datetime.datetime.now()
record.MOD_ID = "user"
session.add(record)
session.commit()
session.close()


def get_row_by_id(table, id):
"""
Gets the row matching the given ID from the given table, raises MissingRecordError if it can not be found
:param table: the table to be searched
:param id: the id of the record to find
:return: the record retrieved
"""
session = get_icat_db_session()
result = session.query(table).filter(table.ID == id).first()
if result is not None:
session.close()
return result
session.close()
raise MissingRecordError()


def delete_row_by_id(table, id):
"""
Deletes the row matching the given ID from the given table, raises MissingRecordError if it can not be found
:param table: the table to be searched
:param id: the id of the record to delete
"""
session = get_icat_db_session()
result = get_row_by_id(table, id)
if result is not None:
session.delete(result)
session.commit()
session.close()
return
session.close()
raise MissingRecordError()


def update_row_from_id(table, id, new_values):
"""
Updates a record in a table
:param table: The table the record is in
:param id: The id of the record
:param new_values: A JSON string containing what columns are to be updated
"""
session = get_icat_db_session()
record = session.query(table).filter(table.ID == id).first()
if record is not None:
record.update_from_dict(new_values)
session.commit()
session.close()
return
session.close()
raise MissingRecordError()


def get_rows_by_filter(table, filters):
session = get_icat_db_session()
base_query = session.query(table)
for filter in filters:
if list(filter)[0].lower() == "where":
for key in filter:
where_part = filter[key]
for k in where_part:
column = getattr(table, k.upper())
base_query = base_query.filter(column.in_([where_part[k]]))

elif list(filter)[0].lower() == "order":
base_query.order() # do something probably not .order
elif list(filter)[0].lower() == "skip":
for key in filter:
skip = filter[key]
base_query = base_query.offset(skip)
elif list(filter)[0].lower() == "include":
base_query.include() # do something probably not .include
elif list(filter)[0].lower() == "limit":
for key in filter:
limit = filter[key]
base_query = base_query.limit(limit)
else:
raise BadFilterError()
session.close()
return list(map(lambda x: x.to_dict(), base_query.all()))


def get_filtered_row_count(table, filters):
"""
returns the count of the rows that match a given filter in a given table
:param table: the table to be checked
:param filters: the filters to be applied to the query
:return: int: the count of the rows
"""
return len(get_rows_by_filter(table, filters))


def get_first_filtered_row(table, filters):
"""
returns the first row that matches a given filter, in a given table
:param table: the table to be checked
:param filters: the filter to be applied to the query
:return: the first row matching the filter
"""
return get_rows_by_filter(table, filters)[0]
10 changes: 10 additions & 0 deletions common/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
class ApiError(Exception):
pass


class MissingRecordError(ApiError):
pass


class BadFilterError(ApiError):
pass
86 changes: 86 additions & 0 deletions common/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import json
from functools import wraps

from flask import request
from flask_restful import reqparse
from sqlalchemy.exc import IntegrityError

from common.database_helpers import get_icat_db_session
from common.exceptions import MissingRecordError, BadFilterError
from common.models.db_models import SESSION


def requires_session_id(method):
"""
Decorator for endpoint resources that makes sure a valid session_id is provided in requests to that endpoint
:param method: The method for the endpoint
:returns a 403, "Forbidden" if a valid session_id is not provided with the request
"""

@wraps(method)
def wrapper_requires_session(*args, **kwargs):
session = get_icat_db_session()
query = session.query(SESSION).filter(
SESSION.ID == get_session_id_from_auth_header()).first()
if query is not None:
return method(*args, **kwargs)
else:
return "Forbidden", 403

return wrapper_requires_session


def queries_records(method):
"""
Decorator for endpoint resources that search for a record in a table
:param method: The method for the endpoint
:return: Will return a 404, "No such record" if a MissingRecordError is caught
"""

@wraps(method)
def wrapper_gets_records(*args, **kwargs):
try:
return method(*args, **kwargs)
except MissingRecordError:
return "No such record in table", 404
except BadFilterError:
return "Invalid filter requested", 400
except ValueError:
return "Bad request", 400
except TypeError:
return "Bad request", 400
except IntegrityError as e:
return "Bad request", 400

return wrapper_gets_records


def get_session_id_from_auth_header():
"""
Gets the sessionID from the Authorization header of a request
:return: String: SessionID
"""
parser = reqparse.RequestParser()
parser.add_argument("Authorization", location="headers")
args = parser.parse_args()
if args["Authorization"] is not None:
return args["Authorization"]
return ""


def is_valid_json(string):
"""
Determines if a string is valid JSON
:param string: The string to be tested
:return: boolean representing if the string is valid JSON
"""
try:
json_object = json.loads(string)
except ValueError:
return False
return True


def get_filters_from_query_string():
filters = request.args.getlist("filter")
return list(map(lambda x: json.loads(x), filters))
Empty file added common/models/__init__.py
Empty file.
Loading

0 comments on commit b60fb4f

Please sign in to comment.