diff --git a/test/test_get_entity_object.py b/test/test_get_entity_object.py new file mode 100644 index 00000000..19834a2f --- /dev/null +++ b/test/test_get_entity_object.py @@ -0,0 +1,28 @@ +import pytest + +from datagateway_api.common.database.models import FACILITY, INVESTIGATION, JOB +from datagateway_api.common.exceptions import ApiError +from datagateway_api.common.helpers import get_entity_object_from_name + + +class TestGetEntityObject: + @pytest.mark.parametrize( + "entity_name, expected_object_type", + [ + pytest.param( + "investigation", type(INVESTIGATION), id="singular entity name", + ), + pytest.param("jobs", type(JOB), id="plural entity name, 's' added"), + pytest.param( + "facilities", type(FACILITY), id="plural entity name, 'y' to 'ies'", + ), + ], + ) + def test_valid_get_entity_object_from_name(self, entity_name, expected_object_type): + database_entity = get_entity_object_from_name(entity_name) + + assert type(database_entity) == expected_object_type + + def test_invalid_get_entity_object_from_name(self): + with pytest.raises(ApiError): + get_entity_object_from_name("Application1234s")