Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor config to be aware of the backend in use #222

Merged
merged 14 commits into from
Apr 26, 2021
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -284,8 +284,8 @@ pre-commit install
Depending on the backend you want to use (either `db` or `python_icat`, more details
about backends [here](#backends)) the connection URL for the backend needs to be set.
These are set in `config.json` (an example file is provided in the base directory of
this repository). While both `DB_URL` and `ICAT_URL` should have values assigned to them
(for best practice), `DB_URL` will only be used for the database backend, and `ICAT_URL`
this repository). While both `db_url` and `icat_url` should have values assigned to them
(for best practice), `db_url` will only be used for the database backend, and `icat_url`
will only be used for the Python ICAT backend. Copy `config.json.example` to
`config.json` and set the values as needed. If you need to create an instance of ICAT,
there are a number of markdown-formatted tutorials that can be found on the
Expand Down Expand Up @@ -706,7 +706,7 @@ flags `-s` or `--seed` for the seed, and `-y` or `--years` for the number of yea
example: `python -m util.icat_db_generator -s 4 -y 10` Would set the seed to 4 and
generate 10 years of data.

This uses code from the API's Database Backend, so a suitable `DB_URL` should be
This uses code from the API's Database Backend, so a suitable `db_url` should be
configured in `config.json`.


Expand Down
160 changes: 81 additions & 79 deletions datagateway_api/common/config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from enum import Enum
import json
import logging
from pathlib import Path
Expand All @@ -9,17 +10,84 @@
log = logging.getLogger()


class APIConfigOptions(Enum):
"""
Class to map config keys to variables in Python - implemented for ease of
development (IntelliSense in IDEs)
"""

BACKEND = "backend"
DB_URL = "db_url"
DEBUG_MODE = "debug_mode"
FLASK_RELOADER = "flask_reloader"
GENERATE_SWAGGER = "generate_swagger"
HOST = "host"
ICAT_CHECK_CERT = "icat_check_cert"
ICAT_URL = "icat_url"
LOG_LEVEL = "log_level"
LOG_LOCATION = "log_location"
PORT = "port"
TEST_MECHANISM = "test_mechanism"
TEST_USER_CREDENTIALS = "test_user_credentials"


class Config(object):
def __init__(self, path=Path(__file__).parent.parent / "config.json"):
self.path = path
with open(self.path) as target:
self.config = json.load(target)
self._config = json.load(target)

self._check_config_items_exist()

def _check_config_items_exist(self):
"""
A function to check that all config options exist before getting too far into
the setup of the API. This check takes the backend into account, meaning only
the config options for the backend used is required

def get_backend_type(self):
Config options used for testing are not checked here as they should only be used
during tests, not in the typical running of the API

If a config option is missing, this will be picked up in `get_config_value()` by
exiting the application
"""
# These keys are non-backend specific and therefore are mandatory for all uses
config_keys = [
APIConfigOptions.BACKEND,
APIConfigOptions.DEBUG_MODE,
APIConfigOptions.FLASK_RELOADER,
APIConfigOptions.GENERATE_SWAGGER,
APIConfigOptions.HOST,
APIConfigOptions.LOG_LEVEL,
APIConfigOptions.LOG_LOCATION,
APIConfigOptions.PORT,
]

if self.get_config_value(APIConfigOptions.BACKEND) == "python_icat":
icat_backend_specific_config_keys = [
APIConfigOptions.ICAT_CHECK_CERT,
APIConfigOptions.ICAT_URL,
]
config_keys.extend(icat_backend_specific_config_keys)
elif self.get_config_value(APIConfigOptions.BACKEND) == "db":
db_backend_specific_config_keys = [APIConfigOptions.DB_URL]
config_keys.extend(db_backend_specific_config_keys)

for key in config_keys:
self.get_config_value(key)

def get_config_value(self, config_key):
"""
Given a config key, the corresponding config value is returned

:param config_key: Enum of a configuration key that's in `config.json`
:type config_key: :class:`APIConfigOptions`
:return: Config value of the given key
"""
try:
return self.config["backend"]
return self._config[config_key.value]
except KeyError:
sys.exit("Missing config value, backend")
sys.exit(f"Missing config value: {config_key}")
MRichards99 marked this conversation as resolved.
Show resolved Hide resolved

def set_backend_type(self, backend_type):
"""
Expand All @@ -32,88 +100,22 @@ def set_backend_type(self, backend_type):
type must be fetched. This must be done using this module (rather than directly
importing and checking the Flask app's config) to avoid circular import issues.
"""
self.config["backend"] = backend_type

def get_db_url(self):
try:
return self.config["DB_URL"]
except KeyError:
sys.exit("Missing config value, DB_URL")

def is_flask_reloader(self):
try:
return self.config["flask_reloader"]
except KeyError:
sys.exit("Missing config value, flask_reloader")

def get_icat_url(self):
try:
return self.config["ICAT_URL"]
except KeyError:
sys.exit("Missing config value, ICAT_URL")

def get_icat_check_cert(self):
try:
return self.config["icat_check_cert"]
except KeyError:
sys.exit("Missing config value, icat_check_cert")

def get_log_level(self):
try:
return self.config["log_level"]
except KeyError:
sys.exit("Missing config value, log_level")

def get_log_location(self):
try:
return self.config["log_location"]
except KeyError:
sys.exit("Missing config value, log_location")

def is_debug_mode(self):
try:
return self.config["debug_mode"]
except KeyError:
sys.exit("Missing config value, debug_mode")

def is_generate_swagger(self):
try:
return self.config["generate_swagger"]
except KeyError:
sys.exit("Missing config value, generate_swagger")

def get_host(self):
try:
return self.config["host"]
except KeyError:
sys.exit("Missing config value, host")

def get_port(self):
try:
return self.config["port"]
except KeyError:
sys.exit("Missing config value, port")

def get_test_user_credentials(self):
try:
return self.config["test_user_credentials"]
except KeyError:
sys.exit("Missing config value, test_user_credentials")

def get_test_mechanism(self):
try:
return self.config["test_mechanism"]
except KeyError:
sys.exit("Missing config value, test_mechanism")
self._config["backend"] = backend_type

def get_icat_properties(self):
"""
ICAT properties can be retrieved using Python ICAT's client object, however this
requires the client object to be authenticated which may not always be the case
when requesting these properties, hence a HTTP request is sent as an alternative
"""
properties_url = f"{config.get_icat_url()}/icat/properties"
r = requests.request("GET", properties_url, verify=config.get_icat_check_cert())
properties_url = (
f"{config.get_config_value(APIConfigOptions.ICAT_URL)}/icat/properties"
)
r = requests.request(
"GET",
properties_url,
verify=config.get_config_value(APIConfigOptions.ICAT_CHECK_CERT),
)
icat_properties = r.json()

return icat_properties
Expand Down
4 changes: 0 additions & 4 deletions datagateway_api/common/constants.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
from datetime import datetime

from datagateway_api.common.config import config


class Constants:
DATABASE_URL = config.get_db_url()
PYTHON_ICAT_DISTNCT_CONDITION = "!= null"
ICAT_PROPERTIES = config.get_icat_properties()
TEST_MOD_CREATE_DATETIME = datetime(2000, 1, 1)
4 changes: 2 additions & 2 deletions datagateway_api/common/icat/filters.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging

from datagateway_api.common.constants import Constants
from datagateway_api.common.config import config
from datagateway_api.common.exceptions import FilterError
from datagateway_api.common.filters import (
DistinctFieldFilter,
Expand Down Expand Up @@ -163,7 +163,7 @@ def __init__(self, skip_value):
super().__init__(skip_value)

def apply_filter(self, query):
icat_properties = Constants.ICAT_PROPERTIES
icat_properties = config.get_icat_properties()
icat_set_limit(query, self.skip_value, icat_properties["maxEntities"])


Expand Down
5 changes: 3 additions & 2 deletions datagateway_api/common/icat/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
ICATValidationError,
)

from datagateway_api.common.config import config
from datagateway_api.common.config import APIConfigOptions, config
from datagateway_api.common.date_handler import DateHandler
from datagateway_api.common.exceptions import (
AuthenticationError,
Expand Down Expand Up @@ -74,7 +74,8 @@ def wrapper_requires_session(*args, **kwargs):

def create_client():
client = icat.client.Client(
config.get_icat_url(), checkCert=config.get_icat_check_cert(),
config.get_config_value(APIConfigOptions.ICAT_URL),
checkCert=config.get_config_value(APIConfigOptions.ICAT_CHECK_CERT),
)
return client

Expand Down
11 changes: 7 additions & 4 deletions datagateway_api/common/logger_setup.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import logging.config
from pathlib import Path

from datagateway_api.common.config import config
from datagateway_api.common.config import APIConfigOptions, config

LOG_FILE_NAME = Path(config.get_log_location())
LOG_FILE_NAME = Path(config.get_config_value(APIConfigOptions.LOG_LOCATION))
logger_config = {
"version": 1,
"formatters": {
Expand All @@ -14,15 +14,18 @@
},
"handlers": {
"default": {
"level": config.get_log_level(),
"level": config.get_config_value(APIConfigOptions.LOG_LEVEL),
"formatter": "default",
"class": "logging.handlers.RotatingFileHandler",
"filename": LOG_FILE_NAME,
"maxBytes": 5000000,
"backupCount": 10,
},
},
"root": {"level": config.get_log_level(), "handlers": ["default"]},
"root": {
"level": config.get_config_value(APIConfigOptions.LOG_LEVEL),
"handlers": ["default"],
},
}


Expand Down
4 changes: 2 additions & 2 deletions datagateway_api/common/query_filter_factory.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging

from datagateway_api.common.config import config
from datagateway_api.common.config import APIConfigOptions, config
from datagateway_api.common.exceptions import (
ApiError,
FilterError,
Expand All @@ -27,7 +27,7 @@ def get_query_filter(request_filter):
:raises FilterError: If the filter name is not recognised
"""

backend_type = config.get_backend_type()
backend_type = config.get_config_value(APIConfigOptions.BACKEND)
if backend_type == "db":
from datagateway_api.common.database.filters import (
DatabaseDistinctFieldFilter as DistinctFieldFilter,
Expand Down
4 changes: 2 additions & 2 deletions datagateway_api/config.json.example
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
{
"backend": "db",
"DB_URL": "mysql+pymysql://icatdbuser:icatdbuserpw@localhost:3306/icatdb",
"db_url": "mysql+pymysql://icatdbuser:icatdbuserpw@localhost:3306/icatdb",
"flask_reloader": false,
"ICAT_URL": "https://localhost:8181",
"icat_url": "https://localhost:8181",
"icat_check_cert": false,
"log_level": "WARN",
"log_location": "/home/runner/work/datagateway-api/datagateway-api/logs.log",
Expand Down
13 changes: 7 additions & 6 deletions datagateway_api/src/api_start_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
from flask_swagger_ui import get_swaggerui_blueprint

from datagateway_api.common.backends import create_backend
from datagateway_api.common.config import config
from datagateway_api.common.constants import Constants
from datagateway_api.common.config import APIConfigOptions, config
from datagateway_api.common.database.helpers import db
from datagateway_api.src.resources.entities.entity_endpoint import (
get_count_endpoint,
Expand Down Expand Up @@ -64,10 +63,12 @@ def create_app_infrastructure(flask_app):
backend_type = flask_app.config["TEST_BACKEND"]
config.set_backend_type(backend_type)
except KeyError:
backend_type = config.get_backend_type()
backend_type = config.get_config_value(APIConfigOptions.BACKEND)

if backend_type == "db":
flask_app.config["SQLALCHEMY_DATABASE_URI"] = Constants.DATABASE_URL
flask_app.config["SQLALCHEMY_DATABASE_URI"] = config.get_config_value(
APIConfigOptions.DB_URL,
)
flask_app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
db.init_app(flask_app)

Expand All @@ -81,7 +82,7 @@ def create_api_endpoints(flask_app, api, spec):
backend_type = flask_app.config["TEST_BACKEND"]
config.set_backend_type(backend_type)
except KeyError:
backend_type = config.get_backend_type()
backend_type = config.get_config_value(APIConfigOptions.BACKEND)

backend = create_backend(backend_type)

Expand Down Expand Up @@ -154,7 +155,7 @@ def create_api_endpoints(flask_app, api, spec):
def openapi_config(spec):
# Reorder paths (e.g. get, patch, post) so openapi.yaml only changes when there's a
# change to the Swagger docs, rather than changing on each startup
if config.is_generate_swagger():
if config.get_config_value(APIConfigOptions.GENERATE_SWAGGER):
log.debug("Reordering OpenAPI docs to alphabetical order")
for entity_data in spec._paths.values():
for endpoint_name in sorted(entity_data.keys()):
Expand Down
10 changes: 5 additions & 5 deletions datagateway_api/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from flask import Flask

from datagateway_api.common.config import config
from datagateway_api.common.config import APIConfigOptions, config
from datagateway_api.common.logger_setup import setup_logger
from datagateway_api.src.api_start_utils import (
create_api_endpoints,
Expand All @@ -23,8 +23,8 @@

if __name__ == "__main__":
app.run(
host=config.get_host(),
port=config.get_port(),
debug=config.is_debug_mode(),
use_reloader=config.is_flask_reloader(),
host=config.get_config_value(APIConfigOptions.HOST),
port=config.get_config_value(APIConfigOptions.PORT),
debug=config.get_config_value(APIConfigOptions.DEBUG_MODE),
use_reloader=config.get_config_value(APIConfigOptions.FLASK_RELOADER),
)
Loading