Skip to content

Commit

Permalink
#223: Remove underscore prefixes
Browse files Browse the repository at this point in the history
- `_get_field()` is the only function in that class that's used internally, so I've remvoed the underscores from the other functions as I'm not sure I've used them correctly
  • Loading branch information
MRichards99 committed May 17, 2021
1 parent f16cb15 commit 0a202d3
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 17 deletions.
22 changes: 11 additions & 11 deletions datagateway_api/common/database/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@ def __init__(self):
"""
The `distinct_join_flag` tracks if JOINs need to be added to the query - on a
distinct filter, if there's no unrelated fields (i.e. no fields with a
`related_depth` of 1), adding JOINs to the query (using `_add_query_join()`)
`related_depth` of 1), adding JOINs to the query (using `add_query_join()`)
will result in a `sqlalchemy.exc.InvalidRequestError`
"""
self.field = None
self.related_field = None
self.related_related_field = None

def _extract_filter_fields(self, field):
def extract_filter_fields(self, field):
"""
Extract the related fields names and put them into separate variables
Expand Down Expand Up @@ -73,7 +73,7 @@ def _extract_filter_fields(self, field):
else:
raise ValueError(f"Maximum related depth exceeded. {field}'s depth > 3")

def _add_query_join(self, query):
def add_query_join(self, query):
"""
Adds any required JOINs to the query if any related fields have been used in the
filter
Expand All @@ -92,7 +92,7 @@ def _add_query_join(self, query):
included_table = get_entity_object_from_name(self.field)
query.base_query = query.base_query.join(included_table)

def _get_entity_model_for_filter(self, query):
def get_entity_model_for_filter(self, query):
"""
Fetches the appropriate entity model based on the contents of the instance
variables of this class
Expand Down Expand Up @@ -126,11 +126,11 @@ def __init__(self, field, value, operation):
WhereFilter.__init__(self, field, value, operation)
DatabaseFilterUtilities.__init__(self)

self._extract_filter_fields(field)
self.extract_filter_fields(field)

def apply_filter(self, query):
self._add_query_join(query)
field = self._get_entity_model_for_filter(query)
self.add_query_join(query)
field = self.get_entity_model_for_filter(query)

if self.operation == "eq":
query.base_query = query.base_query.filter(field == self.value)
Expand Down Expand Up @@ -167,8 +167,8 @@ def apply_filter(self, query):
try:
distinct_fields = []
for field_name in self.fields:
self._extract_filter_fields(field_name)
distinct_fields.append(self._get_entity_model_for_filter(query))
self.extract_filter_fields(field_name)
distinct_fields.append(self.get_entity_model_for_filter(query))

# Base query must be set to a DISTINCT query before adding JOINs - if these
# actions are done in the opposite order, the JOINs will overwrite the
Expand All @@ -180,8 +180,8 @@ def apply_filter(self, query):
)

for field_name in self.fields:
self._extract_filter_fields(field_name)
self._add_query_join(query)
self.extract_filter_fields(field_name)
self.add_query_join(query)
except AttributeError:
raise FilterError("Bad field requested")

Expand Down
12 changes: 6 additions & 6 deletions test/db/test_database_filter_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class TestDatabaseFilterUtilities:
)
def test_valid_extract_filter_fields(self, input_field, expected_fields):
test_utility = DatabaseFilterUtilities()
test_utility._extract_filter_fields(input_field)
test_utility.extract_filter_fields(input_field)

assert test_utility.field == expected_fields[0]
assert test_utility.related_field == expected_fields[1]
Expand All @@ -40,7 +40,7 @@ def test_invalid_extract_filter_fields(self):
test_utility = DatabaseFilterUtilities()

with pytest.raises(ValueError):
test_utility._extract_filter_fields(
test_utility.extract_filter_fields(
"user.investigationUsers.investigation.summary",
)

Expand All @@ -60,7 +60,7 @@ def test_valid_add_query_join(
table = get_entity_object_from_name("Investigation")

test_utility = DatabaseFilterUtilities()
test_utility._extract_filter_fields(input_field)
test_utility.extract_filter_fields(input_field)

expected_query = ReadQuery(table)
if test_utility.related_related_field:
Expand All @@ -79,7 +79,7 @@ def test_valid_add_query_join(
expected_table = table

with ReadQuery(table) as test_query:
test_utility._add_query_join(test_query)
test_utility.add_query_join(test_query)

# Check the JOIN has been applied
assert str(test_query.base_query) == str(expected_query.base_query)
Expand All @@ -98,7 +98,7 @@ def test_valid_get_entity_model_for_filter(self, input_field):
table = get_entity_object_from_name("Investigation")

test_utility = DatabaseFilterUtilities()
test_utility._extract_filter_fields(input_field)
test_utility.extract_filter_fields(input_field)

if test_utility.related_related_field:
expected_table = get_entity_object_from_name(test_utility.related_field)
Expand All @@ -108,7 +108,7 @@ def test_valid_get_entity_model_for_filter(self, input_field):
expected_table = table

with ReadQuery(table) as test_query:
output_field = test_utility._get_entity_model_for_filter(test_query)
output_field = test_utility.get_entity_model_for_filter(test_query)

# Check the output is correct
field_name_to_fetch = input_field.split(".")[-1]
Expand Down

0 comments on commit 0a202d3

Please sign in to comment.