Skip to content

Commit

Permalink
#119: Allow entity objects to be fetched from plural field names
Browse files Browse the repository at this point in the history
- This feature will be needed when there's camelCase field names for related entities, e.g. for user input on include filters
- This change also moves the location of get_entity_object_from_name() to prevent any issues with circular imports. Since the function no longer makes use of globals(), there's no requirement for the function to be in models.py
  • Loading branch information
MRichards99 committed Feb 1, 2021
1 parent 5b6f055 commit f543007
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 33 deletions.
20 changes: 10 additions & 10 deletions datagateway_api/common/database/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
requires_session_id,
update_row_from_id,
)
from datagateway_api.common.database.models import EntityHelper, SESSION
from datagateway_api.common.database.models import SESSION
from datagateway_api.common.exceptions import AuthenticationError
from datagateway_api.common.helpers import queries_records
from datagateway_api.common.helpers import get_entity_object_from_name, queries_records


log = logging.getLogger()
Expand Down Expand Up @@ -63,49 +63,49 @@ def logout(self, session_id):
@requires_session_id
@queries_records
def get_with_filters(self, session_id, entity_type, filters):
table = EntityHelper.get_entity_object_from_name(entity_type)
table = get_entity_object_from_name(entity_type)
return get_rows_by_filter(table, filters)

@requires_session_id
@queries_records
def create(self, session_id, entity_type, data):
table = EntityHelper.get_entity_object_from_name(entity_type)
table = get_entity_object_from_name(entity_type)
return create_rows_from_json(table, data)

@requires_session_id
@queries_records
def update(self, session_id, entity_type, data):
table = EntityHelper.get_entity_object_from_name(entity_type)
table = get_entity_object_from_name(entity_type)
return patch_entities(table, data)

@requires_session_id
@queries_records
def get_one_with_filters(self, session_id, entity_type, filters):
table = EntityHelper.get_entity_object_from_name(entity_type)
table = get_entity_object_from_name(entity_type)
return get_first_filtered_row(table, filters)

@requires_session_id
@queries_records
def count_with_filters(self, session_id, entity_type, filters):
table = EntityHelper.get_entity_object_from_name(entity_type)
table = get_entity_object_from_name(entity_type)
return get_filtered_row_count(table, filters)

@requires_session_id
@queries_records
def get_with_id(self, session_id, entity_type, id_):
table = EntityHelper.get_entity_object_from_name(entity_type)
table = get_entity_object_from_name(entity_type)
return get_row_by_id(table, id_).to_dict()

@requires_session_id
@queries_records
def delete_with_id(self, session_id, entity_type, id_):
table = EntityHelper.get_entity_object_from_name(entity_type)
table = get_entity_object_from_name(entity_type)
return delete_row_by_id(table, id_)

@requires_session_id
@queries_records
def update_with_id(self, session_id, entity_type, id_, data):
table = EntityHelper.get_entity_object_from_name(entity_type)
table = get_entity_object_from_name(entity_type)
return update_row_from_id(table, id_, data)

@requires_session_id
Expand Down
4 changes: 2 additions & 2 deletions datagateway_api/common/database/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
SkipFilter,
WhereFilter,
)
from datagateway_api.common.helpers import get_entity_object_from_name


log = logging.getLogger()
Expand Down Expand Up @@ -65,9 +66,8 @@ def apply_filter(self, query):
included_included_table,
)
field = getattr(included_included_table, self.included_included_field)

elif self.included_field:
included_table = getattr(models, self.field)
included_table = get_entity_object_from_name(self.field)
query.base_query = query.base_query.join(included_table)
field = getattr(included_table, self.included_field)

Expand Down
21 changes: 1 addition & 20 deletions datagateway_api/common/database/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from sqlalchemy.orm import relationship
from sqlalchemy.orm.collections import InstrumentedList

from datagateway_api.common.exceptions import ApiError, DatabaseError, FilterError
from datagateway_api.common.exceptions import DatabaseError, FilterError

Base = declarative_base()

Expand Down Expand Up @@ -58,25 +58,6 @@ class EntityHelper(ABC):
EntityHelper class that contains methods to be shared across all entities
"""

@staticmethod
def get_entity_object_from_name(entity_name):
"""
From an entity name, this function gets a Python version of that entity for the
database backend
:param entity_name: Name of the entity to fetch a version from this model
:type entity_name: :class:`str`
:return: Object of the entity requested (e.g.
:class:`datagateway_api.common.database.models.INVESTIGATIONINSTRUMENT`)
:raises: KeyError: If an entity model cannot be found as a class in this model
"""
try:
return globals()[entity_name.upper()]
except KeyError:
raise ApiError(
f"Entity class cannot be found, missing class for {entity_name}",
)

def to_dict(self):
"""
Turns the columns and values of an entity into a dictionary
Expand Down
25 changes: 25 additions & 0 deletions datagateway_api/common/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from flask_restful import reqparse
from sqlalchemy.exc import IntegrityError

from datagateway_api.common.database import models
from datagateway_api.common.exceptions import (
ApiError,
AuthenticationError,
Expand Down Expand Up @@ -101,3 +102,27 @@ def get_filters_from_query_string():
return filters
except Exception as e:
raise FilterError(e)


def get_entity_object_from_name(entity_name):
"""
From an entity name, this function gets a Python version of that entity for the
database backend
:param entity_name: Name of the entity to fetch a version from this model
:type entity_name: :class:`str`
:return: Object of the entity requested (e.g.
:class:`datagateway_api.common.database.models.INVESTIGATIONINSTRUMENT`)
:raises: KeyError: If an entity model cannot be found as a class in this model
"""
try:
# If a plural is given, fetch the singular field name
if entity_name[-1] == "s":
entity_name = entity_name[0].upper() + entity_name[1:]
entity_name = endpoints[entity_name]

return getattr(models, entity_name.upper())
except KeyError:
raise ApiError(
f"Entity class cannot be found, missing class for {entity_name}",
)
3 changes: 2 additions & 1 deletion datagateway_api/src/resources/entities/entity_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from sqlalchemy.inspection import inspect

from datagateway_api.common.helpers import get_entity_object_from_name
from datagateway_api.src.resources.entities.entity_endpoint_dict import endpoints


Expand Down Expand Up @@ -36,7 +37,7 @@ def create_entity_models():
for endpoint in endpoints:
params = {}
required = []
endpoint_table = EntityHelper.get_entity_object_from_name(endpoints[endpoint])
endpoint_table = get_entity_object_from_name(endpoints[endpoint])
endpoint_inspection = inspect(endpoint_table)
for column in endpoint_inspection.columns:
# Needed to ensure camelCase field names are used, rather than SNAKE_CASE
Expand Down

0 comments on commit f543007

Please sign in to comment.