Skip to content

Commit

Permalink
#142: Move DB filters to their own directory
Browse files Browse the repository at this point in the history
- This change also moves the database implementation of the backend and the helper functions to a specific database folder. The same will be done with the python_icat versions of these
- Unit tests still pass when using the database backend
  • Loading branch information
MRichards99 committed Jul 29, 2020
1 parent 01bd4bb commit ea45b66
Show file tree
Hide file tree
Showing 10 changed files with 158 additions and 147 deletions.
2 changes: 1 addition & 1 deletion common/backends.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from common.database_backend import DatabaseBackend
from common.database.backend import DatabaseBackend
from common.python_icat_backend import PythonICATBackend
from common.backend import Backend
from common.config import config
Expand Down
5 changes: 2 additions & 3 deletions common/database_backend.py → common/database/backend.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from common.backend import Backend
from common.database_helpers import get_facility_cycles_for_instrument, get_facility_cycles_for_instrument_count, \
from common.database.helpers import get_facility_cycles_for_instrument, get_facility_cycles_for_instrument_count, \
get_investigations_for_instrument_in_facility_cycle, get_investigations_for_instrument_in_facility_cycle_count, \
get_rows_by_filter, create_rows_from_json, patch_entities, get_row_by_id, insert_row_into_table, \
delete_row_by_id, update_row_from_id, get_filtered_row_count, get_first_filtered_row
from common.database_helpers import requires_session_id
delete_row_by_id, update_row_from_id, get_filtered_row_count, get_first_filtered_row, requires_session_id
from common.helpers import queries_records
from common.models.db_models import SESSION
import uuid
Expand Down
125 changes: 125 additions & 0 deletions common/database/filters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
from common.filters import WhereFilter, DistinctFieldFilter, OrderFilter, SkipFilter, LimitFilter, \
IncludeFilter

class DatabaseWhereFilter(WhereFilter):
precedence = 1

def __init__(self, field, value, operation):
self.field = field
self.included_field = None
self.included_included_field = None
self._set_filter_fields()
self.value = value
self.operation = operation

def _set_filter_fields(self):
if self.field.count(".") == 1:
self.included_field = self.field.split(".")[1]
self.field = self.field.split(".")[0]

if self.field.count(".") == 2:
self.included_included_field = self.field.split(".")[2]
self.included_field = self.field.split(".")[1]
self.field = self.field.split(".")[0]

def apply_filter(self, query):
try:
field = getattr(query.table, self.field)
except AttributeError:
raise BadFilterError(f"Bad WhereFilter requested")

if self.included_included_field:
included_table = getattr(db_models, self.field)
included_included_table = getattr(db_models, self.included_field)
query.base_query = query.base_query.join(
included_table).join(included_included_table)
field = getattr(included_included_table,
self.included_included_field)

elif self.included_field:
included_table = getattr(db_models, self.field)
query.base_query = query.base_query.join(included_table)
field = getattr(included_table, self.included_field)

if self.operation == "eq":
query.base_query = query.base_query.filter(field == self.value)
elif self.operation == "like":
query.base_query = query.base_query.filter(
field.like(f"%{self.value}%"))
elif self.operation == "lte":
query.base_query = query.base_query.filter(field <= self.value)
elif self.operation == "gte":
query.base_query = query.base_query.filter(field >= self.value)
elif self.operation == "in":
query.base_query = query.base_query.filter(field.in_(self.value))
else:
raise BadFilterError(
f" Bad operation given to where filter. operation: {self.operation}")


class DatabaseDistinctFieldFilter(DistinctFieldFilter):
precedence = 0

def __init__(self, fields):
# This allows single string distinct filters
self.fields = fields if type(fields) is list else [fields]

def apply_filter(self, query):
query.is_distinct_fields_query = True
try:
self.fields = [getattr(query.table, field)
for field in self.fields]
except AttributeError:
raise BadFilterError("Bad field requested")
query.base_query = query.session.query(*self.fields).distinct()


class DatabaseOrderFilter(OrderFilter):
precedence = 2

def __init__(self, field, direction):
self.field = field
self.direction = direction

def apply_filter(self, query):
if self.direction.upper() == "ASC":
query.base_query = query.base_query.order_by(
asc(self.field.upper()))
elif self.direction.upper() == "DESC":
query.base_query = query.base_query.order_by(
desc(self.field.upper()))
else:
raise BadFilterError(f" Bad filter: {self.direction}")


class DatabaseSkipFilter(SkipFilter):
precedence = 3

def __init__(self, skip_value):
self.skip_value = skip_value

def apply_filter(self, query):
query.base_query = query.base_query.offset(self.skip_value)


class DatabaseLimitFilter(LimitFilter):
precedence = 4

def __init__(self, limit_value):
self.limit_value = limit_value

def apply_filter(self, query):
query.base_query = query.base_query.limit(self.limit_value)


class DatabaseIncludeFilter(IncludeFilter):
precedence = 5

def __init__(self, included_filters):
self.included_filters = included_filters["include"]

def apply_filter(self, query):
if not query.include_related_entities:
query.include_related_entities = True
else:
raise MultipleIncludeError()
147 changes: 12 additions & 135 deletions common/database_helpers.py → common/database/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,18 @@
INVESTIGATIONINSTRUMENT, FACILITY, SESSION
from common.session_manager import session_manager
from common.filters import FilterOrderHandler
from common.config import config

backend_type = config.get_backend_type()
if backend_type == "db":
from common.database.filters import DatabaseWhereFilter as WhereFilter, DatabaseDistinctFieldFilter as DistinctFieldFilter, \
DatabaseOrderFilter as OrderFilter, DatabaseSkipFilter as SkipFilter, DatabaseLimitFilter as LimitFilter, \
DatabaseIncludeFilter as IncludeFilter
elif backend_type == "python_icat":
pass
else:
# TODO - Check this works
raise ApiError("Cannot select which implementation of filters to import, check the config file has a valid backend type")

log = logging.getLogger()

Expand Down Expand Up @@ -165,141 +177,6 @@ def execute_query(self):
self.commit_changes()


class QueryFilter(ABC):
@property
@abstractmethod
def precedence(self):
pass

@abstractmethod
def apply_filter(self, query):
pass


class WhereFilter(QueryFilter):
precedence = 1

def __init__(self, field, value, operation):
self.field = field
self.included_field = None
self.included_included_field = None
self._set_filter_fields()
self.value = value
self.operation = operation

def _set_filter_fields(self):
if self.field.count(".") == 1:
self.included_field = self.field.split(".")[1]
self.field = self.field.split(".")[0]

if self.field.count(".") == 2:
self.included_included_field = self.field.split(".")[2]
self.included_field = self.field.split(".")[1]
self.field = self.field.split(".")[0]

def apply_filter(self, query):
try:
field = getattr(query.table, self.field)
except AttributeError:
raise BadFilterError(f"Bad WhereFilter requested")

if self.included_included_field:
included_table = getattr(db_models, self.field)
included_included_table = getattr(db_models, self.included_field)
query.base_query = query.base_query.join(
included_table).join(included_included_table)
field = getattr(included_included_table,
self.included_included_field)

elif self.included_field:
included_table = getattr(db_models, self.field)
query.base_query = query.base_query.join(included_table)
field = getattr(included_table, self.included_field)

if self.operation == "eq":
query.base_query = query.base_query.filter(field == self.value)
elif self.operation == "like":
query.base_query = query.base_query.filter(
field.like(f"%{self.value}%"))
elif self.operation == "lte":
query.base_query = query.base_query.filter(field <= self.value)
elif self.operation == "gte":
query.base_query = query.base_query.filter(field >= self.value)
elif self.operation == "in":
query.base_query = query.base_query.filter(field.in_(self.value))
else:
raise BadFilterError(
f" Bad operation given to where filter. operation: {self.operation}")


class DistinctFieldFilter(QueryFilter):
precedence = 0

def __init__(self, fields):
# This allows single string distinct filters
self.fields = fields if type(fields) is list else [fields]

def apply_filter(self, query):
query.is_distinct_fields_query = True
try:
self.fields = [getattr(query.table, field)
for field in self.fields]
except AttributeError:
raise BadFilterError("Bad field requested")
query.base_query = query.session.query(*self.fields).distinct()


class OrderFilter(QueryFilter):
precedence = 2

def __init__(self, field, direction):
self.field = field
self.direction = direction

def apply_filter(self, query):
if self.direction.upper() == "ASC":
query.base_query = query.base_query.order_by(
asc(self.field.upper()))
elif self.direction.upper() == "DESC":
query.base_query = query.base_query.order_by(
desc(self.field.upper()))
else:
raise BadFilterError(f" Bad filter: {self.direction}")


class SkipFilter(QueryFilter):
precedence = 3

def __init__(self, skip_value):
self.skip_value = skip_value

def apply_filter(self, query):
query.base_query = query.base_query.offset(self.skip_value)


class LimitFilter(QueryFilter):
precedence = 4

def __init__(self, limit_value):
self.limit_value = limit_value

def apply_filter(self, query):
query.base_query = query.base_query.limit(self.limit_value)


class IncludeFilter(QueryFilter):
precedence = 5

def __init__(self, included_filters):
self.included_filters = included_filters["include"]

def apply_filter(self, query):
if not query.include_related_entities:
query.include_related_entities = True
else:
raise MultipleIncludeError()


class QueryFilterFactory(object):
@staticmethod
def get_query_filter(filter):
Expand Down
2 changes: 1 addition & 1 deletion common/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from flask_restful import reqparse
from sqlalchemy.exc import IntegrityError

from common.database_helpers import QueryFilterFactory
from common.database.helpers import QueryFilterFactory
from common.exceptions import ApiError, AuthenticationError, BadFilterError, BadRequestError, MissingCredentialsError, MissingRecordError, MultipleIncludeError

log = logging.getLogger()
Expand Down
2 changes: 1 addition & 1 deletion src/resources/entities/entity_endpoint.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from flask import request
from flask_restful import Resource

from common.database_helpers import get_rows_by_filter, create_rows_from_json, patch_entities, get_row_by_id, \
from common.database.helpers import get_rows_by_filter, create_rows_from_json, patch_entities, get_row_by_id, \
delete_row_by_id, update_row_from_id, get_filtered_row_count, get_first_filtered_row
from common.helpers import get_session_id_from_auth_header, get_filters_from_query_string
from common.backends import backend
Expand Down
2 changes: 1 addition & 1 deletion src/resources/non_entities/sessions_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from flask import request
from flask_restful import Resource, reqparse

from common.database_helpers import insert_row_into_table, delete_row_by_id, get_row_by_id
from common.database.helpers import insert_row_into_table, delete_row_by_id, get_row_by_id
from common.helpers import get_session_id_from_auth_header
from common.models.db_models import SESSION
from common.backends import backend
Expand Down
2 changes: 1 addition & 1 deletion src/resources/table_endpoints/table_endpoints.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from flask_restful import Resource

from common.database_helpers import get_facility_cycles_for_instrument, get_facility_cycles_for_instrument_count, \
from common.database.helpers import get_facility_cycles_for_instrument, get_facility_cycles_for_instrument_count, \
get_investigations_for_instrument_in_facility_cycle, get_investigations_for_instrument_in_facility_cycle_count
from common.helpers import get_session_id_from_auth_header, get_filters_from_query_string
from common.backends import backend
Expand Down
16 changes: 13 additions & 3 deletions test/test_database_helpers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,18 @@
from unittest import TestCase

#from common.filters import QueryFilterFactory
from common.database_helpers import OrderFilter, LimitFilter, SkipFilter, WhereFilter, \
IncludeFilter, DistinctFieldFilter, QueryFilterFactory
from common.database.helpers import QueryFilterFactory
from common.config import config

backend_type = config.get_backend_type()
if backend_type == "db":
from common.database.filters import DatabaseWhereFilter as WhereFilter, DatabaseDistinctFieldFilter as DistinctFieldFilter, \
DatabaseOrderFilter as OrderFilter, DatabaseSkipFilter as SkipFilter, DatabaseLimitFilter as LimitFilter, \
DatabaseIncludeFilter as IncludeFilter
elif backend_type == "python_icat":
pass
else:
# TODO - Check this works
raise ApiError("Cannot select which implementation of filters to import, check the config file has a valid backend type")


class TestQueryFilterFactory(TestCase):
Expand Down
2 changes: 1 addition & 1 deletion test/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from sqlalchemy.exc import IntegrityError

from common.database_helpers import delete_row_by_id, insert_row_into_table, LimitFilter, DistinctFieldFilter, \
from common.database.helpers import delete_row_by_id, insert_row_into_table, LimitFilter, DistinctFieldFilter, \
IncludeFilter, SkipFilter, WhereFilter, OrderFilter
from common.exceptions import MissingRecordError, BadFilterError, BadRequestError, MissingCredentialsError, \
AuthenticationError
Expand Down

0 comments on commit ea45b66

Please sign in to comment.