Skip to content

Commit

Permalink
#210: Implement config enum class
Browse files Browse the repository at this point in the history
  • Loading branch information
MRichards99 committed Apr 20, 2021
1 parent 171dbdf commit 76bcb7d
Show file tree
Hide file tree
Showing 10 changed files with 115 additions and 95 deletions.
46 changes: 25 additions & 21 deletions datagateway_api/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,27 +56,27 @@ def _check_config_items_exist(self):
"""
# These keys are non-backend specific and therefore are mandatory for all uses
config_keys = [
"backend",
"debug_mode",
"flask_reloader",
"generate_swagger",
"host",
"log_level",
"log_location",
"port",
APIConfigOptions.BACKEND,
APIConfigOptions.DEBUG_MODE,
APIConfigOptions.FLASK_RELOADER,
APIConfigOptions.GENERATE_SWAGGER,
APIConfigOptions.HOST,
APIConfigOptions.LOG_LEVEL,
APIConfigOptions.LOG_LOCATION,
APIConfigOptions.PORT,
]

if self.get_config_value("backend") == "python_icat":
if self.get_config_value(APIConfigOptions.BACKEND) == "python_icat":
icat_backend_specific_config_keys = [
"client_cache_size",
"client_pool_init_size",
"client_pool_max_size",
"icat_check_cert",
"icat_url",
APIConfigOptions.CLIENT_CACHE_SIZE,
APIConfigOptions.CLIENT_POOL_INIT_SIZE,
APIConfigOptions.CLIENT_POOL_MAX_SIZE,
APIConfigOptions.ICAT_CHECK_CERT,
APIConfigOptions.ICAT_URL,
]
config_keys.extend(icat_backend_specific_config_keys)
elif self.get_config_value("backend") == "db":
db_backend_specific_config_keys = ["db_url"]
elif self.get_config_value(APIConfigOptions.BACKEND) == "db":
db_backend_specific_config_keys = [APIConfigOptions.DB_URL]
config_keys.extend(db_backend_specific_config_keys)

for key in config_keys:
Expand All @@ -86,12 +86,12 @@ def get_config_value(self, config_key):
"""
Given a config key, the corresponding config value is returned
:param config_key: Configuration key that matches the contents of `config.json`
:type config_key: :class:`str`
:param config_key: Enum of a configuration key that's in `config.json`
:type config_key: :class:`APIConfigOptions`
:return: Config value of the given key
"""
try:
return self.config[config_key]
return self.config[config_key.value]
except KeyError:
sys.exit(f"Missing config value: {config_key}")

Expand All @@ -114,9 +114,13 @@ def get_icat_properties(self):
requires the client object to be authenticated which may not always be the case
when requesting these properties, hence a HTTP request is sent as an alternative
"""
properties_url = f"{config.get_config_value('icat_url')}/icat/properties"
properties_url = (
f"{config.get_config_value(APIConfigOptions.ICAT_URL)}/icat/properties"
)
r = requests.request(
"GET", properties_url, verify=config.get_config_value("icat_check_cert")
"GET",
properties_url,
verify=config.get_config_value(APIConfigOptions.ICAT_CHECK_CERT),
)
icat_properties = r.json()

Expand Down
6 changes: 3 additions & 3 deletions datagateway_api/common/icat/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
ICATValidationError,
)

from datagateway_api.common.config import config
from datagateway_api.common.config import APIConfigOptions, config
from datagateway_api.common.date_handler import DateHandler
from datagateway_api.common.exceptions import (
AuthenticationError,
Expand Down Expand Up @@ -74,8 +74,8 @@ def wrapper_requires_session(*args, **kwargs):

def create_client():
client = icat.client.Client(
config.get_config_value("icat_url"),
checkCert=config.get_config_value("icat_check_cert"),
config.get_config_value(APIConfigOptions.ICAT_URL),
checkCert=config.get_config_value(APIConfigOptions.ICAT_CHECK_CERT),
)
return client

Expand Down
11 changes: 7 additions & 4 deletions datagateway_api/common/logger_setup.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import logging.config
from pathlib import Path

from datagateway_api.common.config import config
from datagateway_api.common.config import APIConfigOptions, config

LOG_FILE_NAME = Path(config.get_config_value("log_location"))
LOG_FILE_NAME = Path(config.get_config_value(APIConfigOptions.LOG_LOCATION))
logger_config = {
"version": 1,
"formatters": {
Expand All @@ -14,15 +14,18 @@
},
"handlers": {
"default": {
"level": config.get_config_value("log_level"),
"level": config.get_config_value(APIConfigOptions.LOG_LEVEL),
"formatter": "default",
"class": "logging.handlers.RotatingFileHandler",
"filename": LOG_FILE_NAME,
"maxBytes": 5000000,
"backupCount": 10,
},
},
"root": {"level": config.get_config_value("log_level"), "handlers": ["default"]},
"root": {
"level": config.get_config_value(APIConfigOptions.LOG_LEVEL),
"handlers": ["default"],
},
}


Expand Down
4 changes: 2 additions & 2 deletions datagateway_api/common/query_filter_factory.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging

from datagateway_api.common.config import config
from datagateway_api.common.config import APIConfigOptions, config
from datagateway_api.common.exceptions import (
ApiError,
FilterError,
Expand All @@ -27,7 +27,7 @@ def get_query_filter(request_filter):
:raises FilterError: If the filter name is not recognised
"""

backend_type = config.get_config_value("backend")
backend_type = config.get_config_value(APIConfigOptions.BACKEND)
if backend_type == "db":
from datagateway_api.common.database.filters import (
DatabaseDistinctFieldFilter as DistinctFieldFilter,
Expand Down
12 changes: 7 additions & 5 deletions datagateway_api/src/api_start_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from flask_swagger_ui import get_swaggerui_blueprint

from datagateway_api.common.backends import create_backend
from datagateway_api.common.config import config
from datagateway_api.common.config import APIConfigOptions, config
from datagateway_api.common.database.helpers import db
from datagateway_api.src.resources.entities.entity_endpoint import (
get_count_endpoint,
Expand Down Expand Up @@ -63,10 +63,12 @@ def create_app_infrastructure(flask_app):
backend_type = flask_app.config["TEST_BACKEND"]
config.set_backend_type(backend_type)
except KeyError:
backend_type = config.get_config_value("backend")
backend_type = config.get_config_value(APIConfigOptions.BACKEND)

if backend_type == "db":
flask_app.config["SQLALCHEMY_DATABASE_URI"] = config.get_config_value("db_url")
flask_app.config["SQLALCHEMY_DATABASE_URI"] = config.get_config_value(
APIConfigOptions.DB_URL
)
flask_app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
db.init_app(flask_app)

Expand All @@ -80,7 +82,7 @@ def create_api_endpoints(flask_app, api, spec):
backend_type = flask_app.config["TEST_BACKEND"]
config.set_backend_type(backend_type)
except KeyError:
backend_type = config.get_config_value("backend")
backend_type = config.get_config_value(APIConfigOptions.BACKEND)

backend = create_backend(backend_type)

Expand Down Expand Up @@ -153,7 +155,7 @@ def create_api_endpoints(flask_app, api, spec):
def openapi_config(spec):
# Reorder paths (e.g. get, patch, post) so openapi.yaml only changes when there's a
# change to the Swagger docs, rather than changing on each startup
if config.get_config_value("generate_swagger"):
if config.get_config_value(APIConfigOptions.GENERATE_SWAGGER):
log.debug("Reordering OpenAPI docs to alphabetical order")
for entity_data in spec._paths.values():
for endpoint_name in sorted(entity_data.keys()):
Expand Down
10 changes: 5 additions & 5 deletions datagateway_api/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from flask import Flask

from datagateway_api.common.config import config
from datagateway_api.common.config import APIConfigOptions, config
from datagateway_api.common.logger_setup import setup_logger
from datagateway_api.src.api_start_utils import (
create_api_endpoints,
Expand All @@ -23,8 +23,8 @@

if __name__ == "__main__":
app.run(
host=config.get_config_value("host"),
port=config.get_config_value("port"),
debug=config.get_config_value("debug_mode"),
use_reloader=config.get_config_value("flask_reloader"),
host=config.get_config_value(APIConfigOptions.HOST),
port=config.get_config_value(APIConfigOptions.PORT),
debug=config.get_config_value(APIConfigOptions.DEBUG_MODE),
use_reloader=config.get_config_value(APIConfigOptions.FLASK_RELOADER),
)
10 changes: 5 additions & 5 deletions test/icat/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from icat.query import Query
import pytest

from datagateway_api.common.config import config
from datagateway_api.common.config import APIConfigOptions, config
from datagateway_api.src.api_start_utils import (
create_api_endpoints,
create_app_infrastructure,
Expand All @@ -19,12 +19,12 @@
@pytest.fixture(scope="package")
def icat_client():
client = Client(
config.get_config_value("icat_url"),
checkCert=config.get_config_value("icat_check_cert"),
config.get_config_value(APIConfigOptions.ICAT_URL),
checkCert=config.get_config_value(APIConfigOptions.ICAT_CHECK_CERT),
)
client.login(
config.get_config_value("test_mechanism"),
config.get_config_value("test_user_credentials"),
config.get_config_value(APIConfigOptions.TEST_MECHANISM),
config.get_config_value(APIConfigOptions.TEST_USER_CREDENTIALS),
)
return client

Expand Down
47 changes: 26 additions & 21 deletions test/icat/test_session_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from icat.client import Client
import pytest

from datagateway_api.common.config import config
from datagateway_api.common.config import APIConfigOptions, config
from datagateway_api.common.icat.filters import PythonICATWhereFilter


Expand All @@ -28,8 +28,8 @@ def test_get_valid_session_details(
# Check username is correct
assert (
session_details.json["username"]
== f"{config.get_config_value('test_mechanism')}/"
f"{config.get_config_value('test_user_credentials')['username']}"
== f"{config.get_config_value(APIConfigOptions.TEST_MECHANISM)}/"
f"{config.get_config_value(APIConfigOptions.TEST_USER_CREDENTIALS)['username']}"
)

# Check session ID matches the header from the request
Expand Down Expand Up @@ -73,24 +73,26 @@ def test_refresh_session(self, valid_icat_credentials_header, flask_test_app_ica
[
pytest.param(
{
"username": config.get_config_value("test_user_credentials")[
"username"
],
"password": config.get_config_value("test_user_credentials")[
"password"
],
"mechanism": config.get_config_value("test_mechanism"),
"username": config.get_config_value(
APIConfigOptions.TEST_USER_CREDENTIALS
)["username"],
"password": config.get_config_value(
APIConfigOptions.TEST_USER_CREDENTIALS
)["password"],
"mechanism": config.get_config_value(
APIConfigOptions.TEST_MECHANISM
),
},
id="Normal request body",
),
pytest.param(
{
"username": config.get_config_value("test_user_credentials")[
"username"
],
"password": config.get_config_value("test_user_credentials")[
"password"
],
"username": config.get_config_value(
APIConfigOptions.TEST_USER_CREDENTIALS
)["username"],
"password": config.get_config_value(
APIConfigOptions.TEST_USER_CREDENTIALS
)["password"],
},
id="Missing mechanism in request body",
),
Expand All @@ -99,6 +101,7 @@ def test_refresh_session(self, valid_icat_credentials_header, flask_test_app_ica
def test_valid_login(
self, flask_test_app_icat, icat_client, icat_query, request_body,
):
print(request_body)
login_response = flask_test_app_icat.post("/sessions", json=request_body)

icat_client.sessionId = login_response.json["sessionID"]
Expand All @@ -119,7 +122,9 @@ def test_valid_login(
{
"username": "Invalid Username",
"password": "InvalidPassword",
"mechanism": config.get_config_value("test_mechanism"),
"mechanism": config.get_config_value(
APIConfigOptions.TEST_MECHANISM
),
},
403,
id="Invalid credentials",
Expand All @@ -136,12 +141,12 @@ def test_invalid_login(

def test_valid_logout(self, flask_test_app_icat):
client = Client(
config.get_config_value("icat_url"),
checkCert=config.get_config_value("icat_check_cert"),
config.get_config_value(APIConfigOptions.ICAT_URL),
checkCert=config.get_config_value(APIConfigOptions.ICAT_CHECK_CERT),
)
client.login(
config.get_config_value("test_mechanism"),
config.get_config_value("test_user_credentials"),
config.get_config_value(APIConfigOptions.TEST_MECHANISM),
config.get_config_value(APIConfigOptions.TEST_USER_CREDENTIALS),
)
creds_header = {"Authorization": f"Bearer {client.sessionId}"}

Expand Down
Loading

0 comments on commit 76bcb7d

Please sign in to comment.