Skip to content

Commit

Permalink
#223: Add tests for DatabaseFilterUtilities
Browse files Browse the repository at this point in the history
  • Loading branch information
MRichards99 committed May 13, 2021
1 parent 4a79940 commit bc96351
Showing 1 changed file with 104 additions and 0 deletions.
104 changes: 104 additions & 0 deletions test/db/test_database_filter_utilities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import pytest

from datagateway_api.common.database.filters import DatabaseFilterUtilities
from datagateway_api.common.database.helpers import ReadQuery
from datagateway_api.common.exceptions import FilterError
from datagateway_api.common.helpers import get_entity_object_from_name


class TestDatabaseFilterUtilities:
@pytest.mark.parametrize(
"input_field, expected_fields",
[
pytest.param("name", ("name", None, None), id="Unrelated field"),
pytest.param(
"facility.daysUntilRelease",
("facility", "daysUntilRelease", None),
id="Related field matching ICAT schema name",
),
pytest.param(
"FACILITY.daysUntilRelease",
("FACILITY", "daysUntilRelease", None),
id="Related field matching database format (uppercase)",
),
pytest.param(
"user.investigationUsers.role",
("user", "investigationUsers", "role"),
id="Related related field (2 levels deep)",
),
],
)
def test_valid_extract_filter_fields(self, input_field, expected_fields):
test_utility = DatabaseFilterUtilities()
test_utility._extract_filter_fields(input_field)

assert test_utility.field == expected_fields[0]
assert test_utility.related_field == expected_fields[1]
assert test_utility.related_related_field == expected_fields[2]

def test_invalid_extract_filter_fields(self):
test_utility = DatabaseFilterUtilities()

with pytest.raises(ValueError):
test_utility._extract_filter_fields(
"user.investigationUsers.investigation.summary",
)

@pytest.mark.parametrize(
"input_field",
[
pytest.param("name", id="No related fields"),
pytest.param("facility.daysUntilRelease", id="Related field"),
pytest.param(
"investigationUsers.user.fullName", id="Related related field",
),
],
)
def test_valid_add_query_join(
self, flask_test_app_db, input_field,
):
table = get_entity_object_from_name("Investigation")

test_utility = DatabaseFilterUtilities()
test_utility._extract_filter_fields(input_field)

expected_query = ReadQuery(table)
if test_utility.related_related_field:
expected_table = get_entity_object_from_name(test_utility.related_field)

included_table = get_entity_object_from_name(test_utility.field)
expected_query.base_query = expected_query.base_query.join(
included_table,
).join(expected_table)
elif test_utility.related_field:
expected_table = get_entity_object_from_name(test_utility.field)

expected_query = ReadQuery(table)
expected_query.base_query = expected_query.base_query.join(expected_table)
else:
expected_table = table

with ReadQuery(table) as test_query:
output_field = test_utility._add_query_join(test_query)

# Check the JOIN has been applied
assert str(test_query.base_query) == str(expected_query.base_query)

# Check the output is correct
field_name_to_fetch = input_field.split(".")[-1]
assert output_field == getattr(expected_table, field_name_to_fetch)

def test_valid_get_field(self, flask_test_app_db):
table = get_entity_object_from_name("Investigation")

test_utility = DatabaseFilterUtilities()
field = test_utility._get_field(table, "name")

assert field == table.name

def test_invalid_get_field(self, flask_test_app_db):
table = get_entity_object_from_name("Investigation")

test_utility = DatabaseFilterUtilities()
with pytest.raises(FilterError):
test_utility._get_field(table, "unknown")

0 comments on commit bc96351

Please sign in to comment.