From 0a202d3d739aceac0e07e3cc06533753bc9da5ed Mon Sep 17 00:00:00 2001 From: Matthew Richards Date: Mon, 17 May 2021 10:59:05 +0000 Subject: [PATCH] #223: Remove underscore prefixes - `_get_field()` is the only function in that class that's used internally, so I've remvoed the underscores from the other functions as I'm not sure I've used them correctly --- datagateway_api/common/database/filters.py | 22 +++++++++++----------- test/db/test_database_filter_utilities.py | 12 ++++++------ 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/datagateway_api/common/database/filters.py b/datagateway_api/common/database/filters.py index c3c72142..7b773555 100644 --- a/datagateway_api/common/database/filters.py +++ b/datagateway_api/common/database/filters.py @@ -34,14 +34,14 @@ def __init__(self): """ The `distinct_join_flag` tracks if JOINs need to be added to the query - on a distinct filter, if there's no unrelated fields (i.e. no fields with a - `related_depth` of 1), adding JOINs to the query (using `_add_query_join()`) + `related_depth` of 1), adding JOINs to the query (using `add_query_join()`) will result in a `sqlalchemy.exc.InvalidRequestError` """ self.field = None self.related_field = None self.related_related_field = None - def _extract_filter_fields(self, field): + def extract_filter_fields(self, field): """ Extract the related fields names and put them into separate variables @@ -73,7 +73,7 @@ def _extract_filter_fields(self, field): else: raise ValueError(f"Maximum related depth exceeded. {field}'s depth > 3") - def _add_query_join(self, query): + def add_query_join(self, query): """ Adds any required JOINs to the query if any related fields have been used in the filter @@ -92,7 +92,7 @@ def _add_query_join(self, query): included_table = get_entity_object_from_name(self.field) query.base_query = query.base_query.join(included_table) - def _get_entity_model_for_filter(self, query): + def get_entity_model_for_filter(self, query): """ Fetches the appropriate entity model based on the contents of the instance variables of this class @@ -126,11 +126,11 @@ def __init__(self, field, value, operation): WhereFilter.__init__(self, field, value, operation) DatabaseFilterUtilities.__init__(self) - self._extract_filter_fields(field) + self.extract_filter_fields(field) def apply_filter(self, query): - self._add_query_join(query) - field = self._get_entity_model_for_filter(query) + self.add_query_join(query) + field = self.get_entity_model_for_filter(query) if self.operation == "eq": query.base_query = query.base_query.filter(field == self.value) @@ -167,8 +167,8 @@ def apply_filter(self, query): try: distinct_fields = [] for field_name in self.fields: - self._extract_filter_fields(field_name) - distinct_fields.append(self._get_entity_model_for_filter(query)) + self.extract_filter_fields(field_name) + distinct_fields.append(self.get_entity_model_for_filter(query)) # Base query must be set to a DISTINCT query before adding JOINs - if these # actions are done in the opposite order, the JOINs will overwrite the @@ -180,8 +180,8 @@ def apply_filter(self, query): ) for field_name in self.fields: - self._extract_filter_fields(field_name) - self._add_query_join(query) + self.extract_filter_fields(field_name) + self.add_query_join(query) except AttributeError: raise FilterError("Bad field requested") diff --git a/test/db/test_database_filter_utilities.py b/test/db/test_database_filter_utilities.py index 7a19630d..8321bbf8 100644 --- a/test/db/test_database_filter_utilities.py +++ b/test/db/test_database_filter_utilities.py @@ -30,7 +30,7 @@ class TestDatabaseFilterUtilities: ) def test_valid_extract_filter_fields(self, input_field, expected_fields): test_utility = DatabaseFilterUtilities() - test_utility._extract_filter_fields(input_field) + test_utility.extract_filter_fields(input_field) assert test_utility.field == expected_fields[0] assert test_utility.related_field == expected_fields[1] @@ -40,7 +40,7 @@ def test_invalid_extract_filter_fields(self): test_utility = DatabaseFilterUtilities() with pytest.raises(ValueError): - test_utility._extract_filter_fields( + test_utility.extract_filter_fields( "user.investigationUsers.investigation.summary", ) @@ -60,7 +60,7 @@ def test_valid_add_query_join( table = get_entity_object_from_name("Investigation") test_utility = DatabaseFilterUtilities() - test_utility._extract_filter_fields(input_field) + test_utility.extract_filter_fields(input_field) expected_query = ReadQuery(table) if test_utility.related_related_field: @@ -79,7 +79,7 @@ def test_valid_add_query_join( expected_table = table with ReadQuery(table) as test_query: - test_utility._add_query_join(test_query) + test_utility.add_query_join(test_query) # Check the JOIN has been applied assert str(test_query.base_query) == str(expected_query.base_query) @@ -98,7 +98,7 @@ def test_valid_get_entity_model_for_filter(self, input_field): table = get_entity_object_from_name("Investigation") test_utility = DatabaseFilterUtilities() - test_utility._extract_filter_fields(input_field) + test_utility.extract_filter_fields(input_field) if test_utility.related_related_field: expected_table = get_entity_object_from_name(test_utility.related_field) @@ -108,7 +108,7 @@ def test_valid_get_entity_model_for_filter(self, input_field): expected_table = table with ReadQuery(table) as test_query: - output_field = test_utility._get_entity_model_for_filter(test_query) + output_field = test_utility.get_entity_model_for_filter(test_query) # Check the output is correct field_name_to_fetch = input_field.split(".")[-1]