diff --git a/README.md b/README.md index b3379dd6..5c77a581 100644 --- a/README.md +++ b/README.md @@ -285,8 +285,8 @@ pre-commit install Depending on the backend you want to use (either `db` or `python_icat`, more details about backends [here](#backends)) the connection URL for the backend needs to be set. These are set in `config.json` (an example file is provided in the base directory of -this repository). While both `DB_URL` and `ICAT_URL` should have values assigned to them -(for best practice), `DB_URL` will only be used for the database backend, and `ICAT_URL` +this repository). While both `db_url` and `icat_url` should have values assigned to them +(for best practice), `db_url` will only be used for the database backend, and `icat_url` will only be used for the Python ICAT backend. Copy `config.json.example` to `config.json` and set the values as needed. If you need to create an instance of ICAT, there are a number of markdown-formatted tutorials that can be found on the @@ -770,7 +770,7 @@ flags `-s` or `--seed` for the seed, and `-y` or `--years` for the number of yea example: `python -m util.icat_db_generator -s 4 -y 10` Would set the seed to 4 and generate 10 years of data. -This uses code from the API's Database Backend, so a suitable `DB_URL` should be +This uses code from the API's Database Backend, so a suitable `db_url` should be configured in `config.json`. diff --git a/datagateway_api/common/config.py b/datagateway_api/common/config.py index 5956fd19..80665aab 100644 --- a/datagateway_api/common/config.py +++ b/datagateway_api/common/config.py @@ -1,3 +1,4 @@ +from enum import Enum import json import logging from pathlib import Path @@ -9,17 +10,90 @@ log = logging.getLogger() +class APIConfigOptions(Enum): + """ + Class to map config keys to variables in Python - implemented for ease of + development (IntelliSense in IDEs) + """ + + BACKEND = "backend" + CLIENT_CACHE_SIZE = "client_cache_size" + CLIENT_POOL_INIT_SIZE = "client_pool_init_size" + CLIENT_POOL_MAX_SIZE = "client_pool_max_size" + DB_URL = "db_url" + DEBUG_MODE = "debug_mode" + FLASK_RELOADER = "flask_reloader" + GENERATE_SWAGGER = "generate_swagger" + HOST = "host" + ICAT_CHECK_CERT = "icat_check_cert" + ICAT_URL = "icat_url" + LOG_LEVEL = "log_level" + LOG_LOCATION = "log_location" + PORT = "port" + TEST_MECHANISM = "test_mechanism" + TEST_USER_CREDENTIALS = "test_user_credentials" + + class Config(object): def __init__(self, path=Path(__file__).parent.parent / "config.json"): self.path = path with open(self.path) as target: - self.config = json.load(target) + self._config = json.load(target) + + self._check_config_items_exist() + + def _check_config_items_exist(self): + """ + A function to check that all config options exist before getting too far into + the setup of the API. This check takes the backend into account, meaning only + the config options for the backend used is required + + Config options used for testing are not checked here as they should only be used + during tests, not in the typical running of the API + + If a config option is missing, this will be picked up in `get_config_value()` by + exiting the application + """ + # These keys are non-backend specific and therefore are mandatory for all uses + config_keys = [ + APIConfigOptions.BACKEND, + APIConfigOptions.DEBUG_MODE, + APIConfigOptions.FLASK_RELOADER, + APIConfigOptions.GENERATE_SWAGGER, + APIConfigOptions.HOST, + APIConfigOptions.LOG_LEVEL, + APIConfigOptions.LOG_LOCATION, + APIConfigOptions.PORT, + ] + + if self.get_config_value(APIConfigOptions.BACKEND) == "python_icat": + icat_backend_specific_config_keys = [ + APIConfigOptions.CLIENT_CACHE_SIZE, + APIConfigOptions.CLIENT_POOL_INIT_SIZE, + APIConfigOptions.CLIENT_POOL_MAX_SIZE, + APIConfigOptions.ICAT_CHECK_CERT, + APIConfigOptions.ICAT_URL, + ] + config_keys.extend(icat_backend_specific_config_keys) + elif self.get_config_value(APIConfigOptions.BACKEND) == "db": + db_backend_specific_config_keys = [APIConfigOptions.DB_URL] + config_keys.extend(db_backend_specific_config_keys) + + for key in config_keys: + self.get_config_value(key) + + def get_config_value(self, config_key): + """ + Given a config key, the corresponding config value is returned - def get_backend_type(self): + :param config_key: Enum of a configuration key that's in `config.json` + :type config_key: :class:`APIConfigOptions` + :return: Config value of the given key + """ try: - return self.config["backend"] + return self._config[config_key.value] except KeyError: - sys.exit("Missing config value, backend") + sys.exit(f"Missing config value: {config_key.value}") def set_backend_type(self, backend_type): """ @@ -32,97 +106,7 @@ def set_backend_type(self, backend_type): type must be fetched. This must be done using this module (rather than directly importing and checking the Flask app's config) to avoid circular import issues. """ - self.config["backend"] = backend_type - - def get_client_cache_size(self): - try: - return self.config["client_cache_size"] - except KeyError: - sys.exit("Missing config value, client_cache_size") - - def get_client_pool_init_size(self): - try: - return self.config["client_pool_init_size"] - except KeyError: - sys.exit("Missing config value, client_pool_init_size") - - def get_client_pool_max_size(self): - try: - return self.config["client_pool_max_size"] - except KeyError: - sys.exit("Missing config value, client_pool_max_size") - - def get_db_url(self): - try: - return self.config["DB_URL"] - except KeyError: - sys.exit("Missing config value, DB_URL") - - def is_flask_reloader(self): - try: - return self.config["flask_reloader"] - except KeyError: - sys.exit("Missing config value, flask_reloader") - - def get_icat_url(self): - try: - return self.config["ICAT_URL"] - except KeyError: - sys.exit("Missing config value, ICAT_URL") - - def get_icat_check_cert(self): - try: - return self.config["icat_check_cert"] - except KeyError: - sys.exit("Missing config value, icat_check_cert") - - def get_log_level(self): - try: - return self.config["log_level"] - except KeyError: - sys.exit("Missing config value, log_level") - - def get_log_location(self): - try: - return self.config["log_location"] - except KeyError: - sys.exit("Missing config value, log_location") - - def is_debug_mode(self): - try: - return self.config["debug_mode"] - except KeyError: - sys.exit("Missing config value, debug_mode") - - def is_generate_swagger(self): - try: - return self.config["generate_swagger"] - except KeyError: - sys.exit("Missing config value, generate_swagger") - - def get_host(self): - try: - return self.config["host"] - except KeyError: - sys.exit("Missing config value, host") - - def get_port(self): - try: - return self.config["port"] - except KeyError: - sys.exit("Missing config value, port") - - def get_test_user_credentials(self): - try: - return self.config["test_user_credentials"] - except KeyError: - sys.exit("Missing config value, test_user_credentials") - - def get_test_mechanism(self): - try: - return self.config["test_mechanism"] - except KeyError: - sys.exit("Missing config value, test_mechanism") + self._config["backend"] = backend_type def get_icat_properties(self): """ @@ -130,8 +114,14 @@ def get_icat_properties(self): requires the client object to be authenticated which may not always be the case when requesting these properties, hence a HTTP request is sent as an alternative """ - properties_url = f"{config.get_icat_url()}/icat/properties" - r = requests.request("GET", properties_url, verify=config.get_icat_check_cert()) + properties_url = ( + f"{config.get_config_value(APIConfigOptions.ICAT_URL)}/icat/properties" + ) + r = requests.request( + "GET", + properties_url, + verify=config.get_config_value(APIConfigOptions.ICAT_CHECK_CERT), + ) icat_properties = r.json() return icat_properties diff --git a/datagateway_api/common/constants.py b/datagateway_api/common/constants.py index f3b63956..a642c0b6 100644 --- a/datagateway_api/common/constants.py +++ b/datagateway_api/common/constants.py @@ -1,10 +1,6 @@ from datetime import datetime -from datagateway_api.common.config import config - class Constants: - DATABASE_URL = config.get_db_url() PYTHON_ICAT_DISTNCT_CONDITION = "!= null" - ICAT_PROPERTIES = config.get_icat_properties() TEST_MOD_CREATE_DATETIME = datetime(2000, 1, 1) diff --git a/datagateway_api/common/icat/filters.py b/datagateway_api/common/icat/filters.py index 1b68c09b..58ee92b9 100644 --- a/datagateway_api/common/icat/filters.py +++ b/datagateway_api/common/icat/filters.py @@ -1,6 +1,6 @@ import logging -from datagateway_api.common.constants import Constants +from datagateway_api.common.config import config from datagateway_api.common.exceptions import FilterError from datagateway_api.common.filters import ( DistinctFieldFilter, @@ -171,7 +171,7 @@ def __init__(self, skip_value): super().__init__(skip_value) def apply_filter(self, query): - icat_properties = Constants.ICAT_PROPERTIES + icat_properties = config.get_icat_properties() icat_set_limit(query, self.skip_value, icat_properties["maxEntities"]) diff --git a/datagateway_api/common/icat/icat_client_pool.py b/datagateway_api/common/icat/icat_client_pool.py index dad667e2..cd8c2ff0 100644 --- a/datagateway_api/common/icat/icat_client_pool.py +++ b/datagateway_api/common/icat/icat_client_pool.py @@ -3,7 +3,7 @@ from icat.client import Client from object_pool import ObjectPool -from datagateway_api.common.config import config +from datagateway_api.common.config import APIConfigOptions, config log = logging.getLogger() @@ -12,7 +12,10 @@ class ICATClient(Client): """Wrapper class to allow an object pool of client objects to be created""" def __init__(self): - super().__init__(config.get_icat_url(), checkCert=config.get_icat_check_cert()) + super().__init__( + config.get_config_value(APIConfigOptions.ICAT_URL), + checkCert=config.get_config_value(APIConfigOptions.ICAT_CHECK_CERT), + ) # When clients are cleaned up, sessions won't be logged out self.autoLogout = False @@ -33,8 +36,8 @@ def create_client_pool(): return ObjectPool( ICATClient, - min_init=config.get_client_pool_init_size(), - max_capacity=config.get_client_pool_max_size(), + min_init=config.get_config_value(APIConfigOptions.CLIENT_POOL_INIT_SIZE), + max_capacity=config.get_config_value(APIConfigOptions.CLIENT_POOL_MAX_SIZE), max_reusable=0, expires=0, ) diff --git a/datagateway_api/common/icat/lru_cache.py b/datagateway_api/common/icat/lru_cache.py index 7d3cedfe..441c9b6e 100644 --- a/datagateway_api/common/icat/lru_cache.py +++ b/datagateway_api/common/icat/lru_cache.py @@ -2,7 +2,7 @@ from cachetools.lru import LRUCache -from datagateway_api.common.config import config +from datagateway_api.common.config import APIConfigOptions, config log = logging.getLogger() @@ -19,7 +19,9 @@ class ExtendedLRUCache(LRUCache): """ def __init__(self): - super().__init__(maxsize=config.get_client_cache_size()) + super().__init__( + maxsize=config.get_config_value(APIConfigOptions.CLIENT_CACHE_SIZE), + ) def popitem(self): key, client = super().popitem() diff --git a/datagateway_api/common/logger_setup.py b/datagateway_api/common/logger_setup.py index 3c727995..e89ab6b8 100644 --- a/datagateway_api/common/logger_setup.py +++ b/datagateway_api/common/logger_setup.py @@ -1,9 +1,9 @@ import logging.config from pathlib import Path -from datagateway_api.common.config import config +from datagateway_api.common.config import APIConfigOptions, config -LOG_FILE_NAME = Path(config.get_log_location()) +LOG_FILE_NAME = Path(config.get_config_value(APIConfigOptions.LOG_LOCATION)) logger_config = { "version": 1, "formatters": { @@ -14,7 +14,7 @@ }, "handlers": { "default": { - "level": config.get_log_level(), + "level": config.get_config_value(APIConfigOptions.LOG_LEVEL), "formatter": "default", "class": "logging.handlers.RotatingFileHandler", "filename": LOG_FILE_NAME, @@ -22,7 +22,10 @@ "backupCount": 10, }, }, - "root": {"level": config.get_log_level(), "handlers": ["default"]}, + "root": { + "level": config.get_config_value(APIConfigOptions.LOG_LEVEL), + "handlers": ["default"], + }, } diff --git a/datagateway_api/common/query_filter_factory.py b/datagateway_api/common/query_filter_factory.py index e4e90308..078becf3 100644 --- a/datagateway_api/common/query_filter_factory.py +++ b/datagateway_api/common/query_filter_factory.py @@ -1,6 +1,6 @@ import logging -from datagateway_api.common.config import config +from datagateway_api.common.config import APIConfigOptions, config from datagateway_api.common.exceptions import ( ApiError, FilterError, @@ -27,7 +27,7 @@ def get_query_filter(request_filter): :raises FilterError: If the filter name is not recognised """ - backend_type = config.get_backend_type() + backend_type = config.get_config_value(APIConfigOptions.BACKEND) if backend_type == "db": from datagateway_api.common.database.filters import ( DatabaseDistinctFieldFilter as DistinctFieldFilter, diff --git a/datagateway_api/config.json.example b/datagateway_api/config.json.example index be15be6b..68137ce9 100644 --- a/datagateway_api/config.json.example +++ b/datagateway_api/config.json.example @@ -3,9 +3,9 @@ "client_cache_size": 5, "client_pool_init_size": 2, "client_pool_max_size": 5, - "DB_URL": "mysql+pymysql://icatdbuser:icatdbuserpw@localhost:3306/icatdb", + "db_url": "mysql+pymysql://icatdbuser:icatdbuserpw@localhost:3306/icatdb", "flask_reloader": false, - "ICAT_URL": "https://localhost:8181", + "icat_url": "https://localhost:8181", "icat_check_cert": false, "log_level": "WARN", "log_location": "/home/runner/work/datagateway-api/datagateway-api/logs.log", diff --git a/datagateway_api/src/api_start_utils.py b/datagateway_api/src/api_start_utils.py index acc156bb..25cbe69a 100644 --- a/datagateway_api/src/api_start_utils.py +++ b/datagateway_api/src/api_start_utils.py @@ -8,8 +8,7 @@ from flask_swagger_ui import get_swaggerui_blueprint from datagateway_api.common.backends import create_backend -from datagateway_api.common.config import config -from datagateway_api.common.constants import Constants +from datagateway_api.common.config import APIConfigOptions, config from datagateway_api.common.database.helpers import db from datagateway_api.common.icat.icat_client_pool import create_client_pool from datagateway_api.src.resources.entities.entity_endpoint import ( @@ -65,10 +64,12 @@ def create_app_infrastructure(flask_app): backend_type = flask_app.config["TEST_BACKEND"] config.set_backend_type(backend_type) except KeyError: - backend_type = config.get_backend_type() + backend_type = config.get_config_value(APIConfigOptions.BACKEND) if backend_type == "db": - flask_app.config["SQLALCHEMY_DATABASE_URI"] = Constants.DATABASE_URL + flask_app.config["SQLALCHEMY_DATABASE_URI"] = config.get_config_value( + APIConfigOptions.DB_URL, + ) flask_app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False db.init_app(flask_app) @@ -82,7 +83,7 @@ def create_api_endpoints(flask_app, api, spec): backend_type = flask_app.config["TEST_BACKEND"] config.set_backend_type(backend_type) except KeyError: - backend_type = config.get_backend_type() + backend_type = config.get_config_value(APIConfigOptions.BACKEND) backend = create_backend(backend_type) @@ -164,7 +165,7 @@ def create_api_endpoints(flask_app, api, spec): def openapi_config(spec): # Reorder paths (e.g. get, patch, post) so openapi.yaml only changes when there's a # change to the Swagger docs, rather than changing on each startup - if config.is_generate_swagger(): + if config.get_config_value(APIConfigOptions.GENERATE_SWAGGER): log.debug("Reordering OpenAPI docs to alphabetical order") for entity_data in spec._paths.values(): for endpoint_name in sorted(entity_data.keys()): diff --git a/datagateway_api/src/main.py b/datagateway_api/src/main.py index e4ad1c3f..3814be15 100644 --- a/datagateway_api/src/main.py +++ b/datagateway_api/src/main.py @@ -2,7 +2,7 @@ from flask import Flask -from datagateway_api.common.config import config +from datagateway_api.common.config import APIConfigOptions, config from datagateway_api.common.logger_setup import setup_logger from datagateway_api.src.api_start_utils import ( create_api_endpoints, @@ -23,8 +23,8 @@ if __name__ == "__main__": app.run( - host=config.get_host(), - port=config.get_port(), - debug=config.is_debug_mode(), - use_reloader=config.is_flask_reloader(), + host=config.get_config_value(APIConfigOptions.HOST), + port=config.get_config_value(APIConfigOptions.PORT), + debug=config.get_config_value(APIConfigOptions.DEBUG_MODE), + use_reloader=config.get_config_value(APIConfigOptions.FLASK_RELOADER), ) diff --git a/test/icat/conftest.py b/test/icat/conftest.py index 48198dad..4f464ec1 100644 --- a/test/icat/conftest.py +++ b/test/icat/conftest.py @@ -7,7 +7,7 @@ from icat.query import Query import pytest -from datagateway_api.common.config import config +from datagateway_api.common.config import APIConfigOptions, config from datagateway_api.src.api_start_utils import ( create_api_endpoints, create_app_infrastructure, @@ -18,8 +18,14 @@ @pytest.fixture(scope="package") def icat_client(): - client = Client(config.get_icat_url(), checkCert=config.get_icat_check_cert()) - client.login(config.get_test_mechanism(), config.get_test_user_credentials()) + client = Client( + config.get_config_value(APIConfigOptions.ICAT_URL), + checkCert=config.get_config_value(APIConfigOptions.ICAT_CHECK_CERT), + ) + client.login( + config.get_config_value(APIConfigOptions.TEST_MECHANISM), + config.get_config_value(APIConfigOptions.TEST_USER_CREDENTIALS), + ) return client diff --git a/test/icat/test_lru_cache.py b/test/icat/test_lru_cache.py index b400317e..3f27d603 100644 --- a/test/icat/test_lru_cache.py +++ b/test/icat/test_lru_cache.py @@ -3,7 +3,7 @@ from cachetools import cached from icat.client import Client -from datagateway_api.common.config import config +from datagateway_api.common.config import APIConfigOptions, config from datagateway_api.common.icat.icat_client_pool import create_client_pool from datagateway_api.common.icat.lru_cache import ExtendedLRUCache @@ -11,13 +11,16 @@ class TestLRUCache: def test_valid_cache_creation(self): test_cache = ExtendedLRUCache() - assert test_cache.maxsize == config.get_client_cache_size() + assert test_cache.maxsize == config.get_config_value( + APIConfigOptions.CLIENT_CACHE_SIZE, + ) def test_valid_popitem(self): test_cache = ExtendedLRUCache() test_pool = create_client_pool() test_client = Client( - config.get_icat_url(), checkCert=config.get_icat_check_cert(), + config.get_config_value(APIConfigOptions.ICAT_URL), + checkCert=config.get_config_value(APIConfigOptions.ICAT_CHECK_CERT), ) test_cache.popitem = MagicMock(side_effect=test_cache.popitem) @@ -26,7 +29,9 @@ def test_valid_popitem(self): def get_cached_client(cache_number, client_pool): return test_client - for cache_number in range(config.get_client_cache_size() + 1): + for cache_number in range( + config.get_config_value(APIConfigOptions.CLIENT_CACHE_SIZE) + 1, + ): get_cached_client(cache_number, test_pool) assert test_cache.popitem.called diff --git a/test/icat/test_session_handling.py b/test/icat/test_session_handling.py index 14fbe77f..b1884393 100644 --- a/test/icat/test_session_handling.py +++ b/test/icat/test_session_handling.py @@ -3,7 +3,7 @@ from icat.client import Client import pytest -from datagateway_api.common.config import config +from datagateway_api.common.config import APIConfigOptions, config from datagateway_api.common.icat.filters import PythonICATWhereFilter @@ -26,10 +26,11 @@ def test_get_valid_session_details( assert time_diff_minutes < 120 and time_diff_minutes >= 118 # Check username is correct - assert ( - session_details.json["username"] == f"{config.get_test_mechanism()}/" - f"{config.get_test_user_credentials()['username']}" - ) + test_mechanism = config.get_config_value(APIConfigOptions.TEST_MECHANISM) + test_username = config.get_config_value(APIConfigOptions.TEST_USER_CREDENTIALS)[ + "username" + ] + assert session_details.json["username"] == f"{test_mechanism}/{test_username}" # Check session ID matches the header from the request assert ( @@ -72,16 +73,26 @@ def test_refresh_session(self, valid_icat_credentials_header, flask_test_app_ica [ pytest.param( { - "username": config.get_test_user_credentials()["username"], - "password": config.get_test_user_credentials()["password"], - "mechanism": config.get_test_mechanism(), + "username": config.get_config_value( + APIConfigOptions.TEST_USER_CREDENTIALS, + )["username"], + "password": config.get_config_value( + APIConfigOptions.TEST_USER_CREDENTIALS, + )["password"], + "mechanism": config.get_config_value( + APIConfigOptions.TEST_MECHANISM, + ), }, id="Normal request body", ), pytest.param( { - "username": config.get_test_user_credentials()["username"], - "password": config.get_test_user_credentials()["password"], + "username": config.get_config_value( + APIConfigOptions.TEST_USER_CREDENTIALS, + )["username"], + "password": config.get_config_value( + APIConfigOptions.TEST_USER_CREDENTIALS, + )["password"], }, id="Missing mechanism in request body", ), @@ -110,7 +121,9 @@ def test_valid_login( { "username": "Invalid Username", "password": "InvalidPassword", - "mechanism": config.get_test_mechanism(), + "mechanism": config.get_config_value( + APIConfigOptions.TEST_MECHANISM, + ), }, 403, id="Invalid credentials", @@ -126,8 +139,14 @@ def test_invalid_login( assert login_response.status_code == expected_response_code def test_valid_logout(self, flask_test_app_icat): - client = Client(config.get_icat_url(), checkCert=config.get_icat_check_cert()) - client.login(config.get_test_mechanism(), config.get_test_user_credentials()) + client = Client( + config.get_config_value(APIConfigOptions.ICAT_URL), + checkCert=config.get_config_value(APIConfigOptions.ICAT_CHECK_CERT), + ) + client.login( + config.get_config_value(APIConfigOptions.TEST_MECHANISM), + config.get_config_value(APIConfigOptions.TEST_USER_CREDENTIALS), + ) creds_header = {"Authorization": f"Bearer {client.sessionId}"} logout_response = flask_test_app_icat.delete("/sessions", headers=creds_header) diff --git a/test/test_config.py b/test/test_config.py index ccb07c6a..a5fbbcd5 100644 --- a/test/test_config.py +++ b/test/test_config.py @@ -3,189 +3,53 @@ import pytest -from datagateway_api.common.config import Config +from datagateway_api.common.config import APIConfigOptions, Config @pytest.fixture() -def valid_config(): +def test_config(): return Config( path=Path(__file__).parent.parent / "datagateway_api" / "config.json.example", ) -@pytest.fixture() -def invalid_config(): - blank_config_file = tempfile.NamedTemporaryFile(mode="w+", suffix=".json") - blank_config_file.write("{}") - blank_config_file.seek(0) - - return Config(path=blank_config_file.name) - - -class TestGetBackendType: - def test_valid_backend_type(self, valid_config): - backend_type = valid_config.get_backend_type() +class TestConfig: + def test_valid_get_config_value(self, test_config): + backend_type = test_config.get_config_value(APIConfigOptions.BACKEND) assert backend_type == "db" - def test_invalid_backend_type(self, invalid_config): - with pytest.raises(SystemExit): - invalid_config.get_backend_type() - - -class TestGetClientCacheSize: - def test_valid_client_cache_size(self, valid_config): - cache_size = valid_config.get_client_cache_size() - assert cache_size == 5 - - def test_invalid_client_cache_size(self, invalid_config): - with pytest.raises(SystemExit): - invalid_config.get_client_cache_size() - - -class TestGetClientPoolInitSize: - def test_valid_client_pool_init_size(self, valid_config): - pool_init_size = valid_config.get_client_pool_init_size() - assert pool_init_size == 2 - - def test_invalid_client_cache_size(self, invalid_config): - with pytest.raises(SystemExit): - invalid_config.get_client_pool_init_size() - - -class TestGetClientPoolMaxSize: - def test_valid_client_pool_init_size(self, valid_config): - pool_max_size = valid_config.get_client_pool_max_size() - assert pool_max_size == 5 - - def test_invalid_client_cache_size(self, invalid_config): - with pytest.raises(SystemExit): - invalid_config.get_client_pool_max_size() - - -class TestGetDBURL: - def test_valid_db_url(self, valid_config): - db_url = valid_config.get_db_url() - assert db_url == "mysql+pymysql://icatdbuser:icatdbuserpw@localhost:3306/icatdb" - - def test_invalid_db_url(self, invalid_config): + def test_invalid_get_config_value(self, test_config): + del test_config._config["backend"] with pytest.raises(SystemExit): - invalid_config.get_db_url() - - -class TestIsFlaskReloader: - def test_valid_flask_reloader(self, valid_config): - flask_reloader = valid_config.is_flask_reloader() - assert flask_reloader is False - - def test_invalid_flask_reloader(self, invalid_config): - with pytest.raises(SystemExit): - invalid_config.is_flask_reloader() - - -class TestICATURL: - def test_valid_icat_url(self, valid_config): - icat_url = valid_config.get_icat_url() - assert icat_url == "https://localhost:8181" - - def test_invalid_icat_url(self, invalid_config): - with pytest.raises(SystemExit): - invalid_config.get_icat_url() - - -class TestICATCheckCert: - def test_valid_icat_check_cert(self, valid_config): - icat_check_cert = valid_config.get_icat_check_cert() - assert icat_check_cert is False + test_config.get_config_value(APIConfigOptions.BACKEND) - def test_invalid_icat_check_cert(self, invalid_config): - with pytest.raises(SystemExit): - invalid_config.get_icat_check_cert() - - -class TestGetLogLevel: - def test_valid_log_level(self, valid_config): - log_level = valid_config.get_log_level() - assert log_level == "WARN" - - def test_invalid_log_level(self, invalid_config): - with pytest.raises(SystemExit): - invalid_config.get_log_level() - - -class TestGetLogLocation: - def test_valid_log_location(self, valid_config): - log_location = valid_config.get_log_location() - assert ( - log_location == "/home/runner/work/datagateway-api/datagateway-api/logs.log" - ) - - def test_invalid_log_location(self, invalid_config): - with pytest.raises(SystemExit): - invalid_config.get_log_location() - - -class TestIsDebugMode: - def test_valid_debug_mode(self, valid_config): - debug_mode = valid_config.is_debug_mode() - assert debug_mode is False - - def test_invalid_debug_mode(self, invalid_config): - with pytest.raises(SystemExit): - invalid_config.is_debug_mode() - - -class TestIsGenerateSwagger: - def test_valid_generate_swagger(self, valid_config): - generate_swagger = valid_config.is_generate_swagger() - assert generate_swagger is False - - def test_invalid_generate_swagger(self, invalid_config): - with pytest.raises(SystemExit): - invalid_config.is_generate_swagger() - - -class TestGetHost: - def test_valid_host(self, valid_config): - host = valid_config.get_host() - assert host == "127.0.0.1" - - def test_invalid_host(self, invalid_config): - with pytest.raises(SystemExit): - invalid_config.get_host() - - -class TestGetPort: - def test_valid_port(self, valid_config): - port = valid_config.get_port() - assert port == "5000" - - def test_invalid_port(self, invalid_config): - with pytest.raises(SystemExit): - invalid_config.get_port() + @pytest.mark.parametrize( + "backend_type", + [ + pytest.param("python_icat", id="Python ICAT Backend"), + pytest.param("db", id="Database Backend"), + ], + ) + def test_valid_config_items_exist(self, test_config, backend_type): + test_config._config["backend"] = backend_type + # Just want to check no SysExit's, so no assert is needed + test_config._check_config_items_exist() -class TestGetTestUserCredentials: - def test_valid_test_user_credentials(self, valid_config): - test_user_credentials = valid_config.get_test_user_credentials() - assert test_user_credentials == {"username": "root", "password": "pw"} + def test_invalid_config_items_exist(self): + blank_config_file = tempfile.NamedTemporaryFile(mode="w+", suffix=".json") + blank_config_file.write("{}") + blank_config_file.seek(0) - def test_invalid_test_user_credentials(self, invalid_config): with pytest.raises(SystemExit): - invalid_config.get_test_user_credentials() + Config(path=blank_config_file.name) + def test_valid_set_backend_type(self, test_config): + test_config.set_backend_type("backend_name_changed") -class TestGetTestMechanism: - def test_valid_test_mechanism(self, valid_config): - test_mechanism = valid_config.get_test_mechanism() - assert test_mechanism == "simple" - - def test_invalid_test_mechanism(self, invalid_config): - with pytest.raises(SystemExit): - invalid_config.get_test_mechanism() - + assert test_config._config["backend"] == "backend_name_changed" -class TestGetICATProperties: - def test_valid_icat_properties(self, valid_config): + def test_valid_icat_properties(self, test_config): example_icat_properties = { "maxEntities": 10000, "lifetimeMinutes": 120, @@ -199,6 +63,6 @@ def test_valid_icat_properties(self, valid_config): "containerType": "Glassfish", } - icat_properties = valid_config.get_icat_properties() + icat_properties = test_config.get_icat_properties() # Values could vary across versions, less likely that keys will assert icat_properties.keys() == example_icat_properties.keys() diff --git a/util/icat_db_generator.py b/util/icat_db_generator.py index 5a5268a5..a4a14303 100644 --- a/util/icat_db_generator.py +++ b/util/icat_db_generator.py @@ -9,7 +9,7 @@ from sqlalchemy.orm import scoped_session, sessionmaker from sqlalchemy.pool import QueuePool -from datagateway_api.common.constants import Constants +from datagateway_api.common.config import APIConfigOptions, config from datagateway_api.common.database import models parser = argparse.ArgumentParser() @@ -38,7 +38,10 @@ engine = create_engine( - Constants.DATABASE_URL, poolclass=QueuePool, pool_size=100, max_overflow=0, + config.get_config_value(APIConfigOptions.DB_URL), + poolclass=QueuePool, + pool_size=100, + max_overflow=0, ) session_factory = sessionmaker(engine) session = scoped_session(session_factory)()