Skip to content

Commit

Permalink
refactor: move config to its own class #258
Browse files Browse the repository at this point in the history
- This change has been made so the config can be mocked in tests
  • Loading branch information
MRichards99 committed Nov 25, 2021
1 parent 71cc98a commit 1cf73e4
Show file tree
Hide file tree
Showing 33 changed files with 200 additions and 182 deletions.
28 changes: 15 additions & 13 deletions datagateway_api/src/api_start_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
from flask_restful import Api
from flask_swagger_ui import get_swaggerui_blueprint

from datagateway_api.src.common.config import config
from datagateway_api.src.common.config import Config

# Only attempt to create a DataGateway API backend if the datagateway_api object
# is present in the config. This ensures that the API does not error on startup
# due to an AttributeError exception being thrown if the object is missing.
if config.datagateway_api is not None:
if Config.config.datagateway_api is not None:
from datagateway_api.src.datagateway_api.backends import create_backend
from datagateway_api.src.datagateway_api.database.helpers import db # noqa: I202
from datagateway_api.src.datagateway_api.icat.icat_client_pool import create_client_pool
Expand Down Expand Up @@ -73,15 +73,17 @@ def create_app_infrastructure(flask_app):
flask_app.url_map.strict_slashes = False
api = CustomErrorHandledApi(flask_app)

if config.datagateway_api is not None:
if Config.config.datagateway_api is not None:
try:
backend_type = flask_app.config["TEST_BACKEND"]
config.datagateway_api.set_backend_type(backend_type)
Config.config.datagateway_api.set_backend_type(backend_type)
except KeyError:
backend_type = config.datagateway_api.backend
backend_type = Config.config.datagateway_api.backend

if backend_type == "db":
flask_app.config["SQLALCHEMY_DATABASE_URI"] = config.datagateway_api.db_url
flask_app.config[
"SQLALCHEMY_DATABASE_URI"
] = Config.config.datagateway_api.db_url
flask_app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
db.init_app(flask_app)

Expand All @@ -92,12 +94,12 @@ def create_app_infrastructure(flask_app):

def create_api_endpoints(flask_app, api, spec):
# DataGateway API endpoints
if config.datagateway_api is not None:
if Config.config.datagateway_api is not None:
try:
backend_type = flask_app.config["TEST_BACKEND"]
config.datagateway_api.set_backend_type(backend_type)
Config.config.datagateway_api.set_backend_type(backend_type)
except KeyError:
backend_type = config.datagateway_api.backend
backend_type = Config.config.datagateway_api.backend

backend = create_backend(backend_type)

Expand All @@ -106,7 +108,7 @@ def create_api_endpoints(flask_app, api, spec):
# Create client pool
icat_client_pool = create_client_pool()

datagateway_api_extension = config.datagateway_api.extension
datagateway_api_extension = Config.config.datagateway_api.extension
for entity_name in endpoints:
get_endpoint_resource = get_endpoint(
entity_name,
Expand Down Expand Up @@ -220,8 +222,8 @@ def create_api_endpoints(flask_app, api, spec):
spec.path(resource=ping_resource, api=api)

# Search API endpoints
if config.search_api is not None:
search_api_extension = config.search_api.extension
if Config.config.search_api is not None:
search_api_extension = Config.config.search_api.extension
search_api_entity_endpoints = ["datasets", "documents", "instruments"]

for entity_name in search_api_entity_endpoints:
Expand Down Expand Up @@ -271,7 +273,7 @@ def create_api_endpoints(flask_app, api, spec):
def openapi_config(spec):
# Reorder paths (e.g. get, patch, post) so openapi.yaml only changes when there's a
# change to the Swagger docs, rather than changing on each startup
if config.generate_swagger:
if Config.config.generate_swagger:
log.debug("Reordering OpenAPI docs to alphabetical order")
for entity_data in spec._paths.values():
for endpoint_name in sorted(entity_data.keys()):
Expand Down
7 changes: 4 additions & 3 deletions datagateway_api/src/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,6 @@ class SearchAPI(BaseModel):
validation of the SearchAPI config data using Python type annotations.
"""

client_pool_init_size: StrictInt
client_pool_max_size: StrictInt
extension: StrictStr
icat_check_cert: StrictBool
icat_url: StrictStr
Expand Down Expand Up @@ -215,4 +213,7 @@ def validate_api_extensions(cls, value, values): # noqa: B902, N805
return value


config = APIConfig.load()
class Config:
"""Class containing config as a class variable so it can mocked during testing"""

config = APIConfig.load()
8 changes: 4 additions & 4 deletions datagateway_api/src/common/logger_setup.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import logging.config
from pathlib import Path

from datagateway_api.src.common.config import config
from datagateway_api.src.common.config import Config

LOG_FILE_NAME = Path(config.log_location)
LOG_FILE_NAME = Path(Config.config.log_location)
logger_config = {
"version": 1,
"formatters": {
Expand All @@ -14,15 +14,15 @@
},
"handlers": {
"default": {
"level": config.log_level,
"level": Config.config.log_level,
"formatter": "default",
"class": "logging.handlers.RotatingFileHandler",
"filename": LOG_FILE_NAME,
"maxBytes": 5000000,
"backupCount": 10,
},
},
"root": {"level": config.log_level, "handlers": ["default"]},
"root": {"level": Config.config.log_level, "handlers": ["default"]},
}


Expand Down
5 changes: 3 additions & 2 deletions datagateway_api/src/datagateway_api/icat/filters.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging

from datagateway_api.src.common.config import config
from datagateway_api.src.common.config import Config
from datagateway_api.src.common.exceptions import FilterError
from datagateway_api.src.common.filters import (
DistinctFieldFilter,
Expand Down Expand Up @@ -215,7 +215,8 @@ def __init__(self, skip_value):

def apply_filter(self, query):
icat_properties = get_icat_properties(
config.datagateway_api.icat_url, config.datagateway_api.icat_check_cert,
Config.config.datagateway_api.icat_url,
Config.config.datagateway_api.icat_check_cert,
)
icat_set_limit(query, self.skip_value, icat_properties["maxEntities"])

Expand Down
14 changes: 7 additions & 7 deletions datagateway_api/src/datagateway_api/icat/icat_client_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from icat.client import Client
from object_pool import ObjectPool

from datagateway_api.src.common.config import config
from datagateway_api.src.common.config import Config

log = logging.getLogger()

Expand All @@ -13,12 +13,12 @@ class ICATClient(Client):

def __init__(self, client_use="datagateway_api"):
if client_use == "datagateway_api":
icat_url = config.datagateway_api.icat_url
icat_check_cert = config.datagateway_api.icat_check_cert
icat_url = Config.config.datagateway_api.icat_url
icat_check_cert = Config.config.datagateway_api.icat_check_cert
else:
# Search API use cases
icat_url = config.search_api.icat_url
icat_check_cert = config.search_api.icat_check_cert
icat_url = Config.config.search_api.icat_url
icat_check_cert = Config.config.search_api.icat_check_cert

super().__init__(icat_url, checkCert=icat_check_cert)
# When clients are cleaned up, sessions won't be logged out
Expand All @@ -41,8 +41,8 @@ def create_client_pool():

return ObjectPool(
ICATClient,
min_init=config.datagateway_api.client_pool_init_size,
max_capacity=config.datagateway_api.client_pool_max_size,
min_init=Config.config.datagateway_api.client_pool_init_size,
max_capacity=Config.config.datagateway_api.client_pool_max_size,
max_reusable=0,
expires=0,
)
4 changes: 2 additions & 2 deletions datagateway_api/src/datagateway_api/icat/lru_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from cachetools.lru import LRUCache

from datagateway_api.src.common.config import config
from datagateway_api.src.common.config import Config

log = logging.getLogger()

Expand All @@ -19,7 +19,7 @@ class ExtendedLRUCache(LRUCache):
"""

def __init__(self):
super().__init__(maxsize=config.datagateway_api.client_cache_size)
super().__init__(maxsize=Config.config.datagateway_api.client_cache_size)

def popitem(self):
key, client = super().popitem()
Expand Down
4 changes: 2 additions & 2 deletions datagateway_api/src/datagateway_api/query_filter_factory.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging

from datagateway_api.src.common.config import config
from datagateway_api.src.common.config import Config
from datagateway_api.src.common.exceptions import (
ApiError,
FilterError,
Expand All @@ -27,7 +27,7 @@ def get_query_filter(request_filter):
:raises FilterError: If the filter name is not recognised
"""

backend_type = config.datagateway_api.backend
backend_type = Config.config.datagateway_api.backend
if backend_type == "db":
from datagateway_api.src.datagateway_api.database.filters import (
DatabaseDistinctFieldFilter as DistinctFieldFilter,
Expand Down
10 changes: 5 additions & 5 deletions datagateway_api/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
create_openapi_endpoint,
openapi_config,
)
from datagateway_api.src.common.config import config
from datagateway_api.src.common.config import Config
from datagateway_api.src.common.logger_setup import setup_logger

setup_logger()
Expand All @@ -23,8 +23,8 @@

if __name__ == "__main__":
app.run(
host=config.host,
port=config.port,
debug=config.debug_mode,
use_reloader=config.flask_reloader,
host=Config.config.host,
port=Config.config.port,
debug=Config.config.debug_mode,
use_reloader=Config.config.flask_reloader,
)
8 changes: 5 additions & 3 deletions test/datagateway_api/db/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import pytest

from datagateway_api.src.common.config import config
from datagateway_api.src.common.config import Config
from datagateway_api.src.common.constants import Constants
from datagateway_api.src.datagateway_api.database.helpers import (
delete_row_by_id,
Expand Down Expand Up @@ -117,7 +117,8 @@ def isis_specific_endpoint_data_db():
@pytest.fixture()
def final_instrument_id(flask_test_app_db, valid_db_credentials_header):
final_instrument_result = flask_test_app_db.get(
f'{config.datagateway_api.extension}/instruments/findone?order="id DESC"',
f"{Config.config.datagateway_api.extension}/instruments/findone"
'?order="id DESC"',
headers=valid_db_credentials_header,
)
return final_instrument_result.json["id"]
Expand All @@ -126,7 +127,8 @@ def final_instrument_id(flask_test_app_db, valid_db_credentials_header):
@pytest.fixture()
def final_facilitycycle_id(flask_test_app_db, valid_db_credentials_header):
final_facilitycycle_result = flask_test_app_db.get(
f'{config.datagateway_api.extension}/facilitycycles/findone?order="id DESC"',
f"{Config.config.datagateway_api.extension}/facilitycycles/findone"
'?order="id DESC"',
headers=valid_db_credentials_header,
)
return final_facilitycycle_result.json["id"]
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from datagateway_api.src.common.config import config
from datagateway_api.src.common.config import Config


class TestDBCountWithFilters:
Expand All @@ -9,7 +9,7 @@ def test_valid_count_with_filters(
self, flask_test_app_db, valid_db_credentials_header,
):
test_response = flask_test_app_db.get(
f"{config.datagateway_api.extension}/investigations/count?where="
f"{Config.config.datagateway_api.extension}/investigations/count?where="
'{"title": {"like": "Title for DataGateway API Testing (DB)"}}',
headers=valid_db_credentials_header,
)
Expand All @@ -20,7 +20,7 @@ def test_valid_no_results_count_with_filters(
self, flask_test_app_db, valid_db_credentials_header,
):
test_response = flask_test_app_db.get(
f"{config.datagateway_api.extension}/investigations/count?where="
f"{Config.config.datagateway_api.extension}/investigations/count?where="
'{"title": {"like": "This filter should cause a404 for testing '
'purposes..."}}',
headers=valid_db_credentials_header,
Expand Down
6 changes: 3 additions & 3 deletions test/datagateway_api/db/endpoints/test_findone_db.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from datagateway_api.src.common.config import config
from datagateway_api.src.common.config import Config


class TestDBFindone:
Expand All @@ -9,7 +9,7 @@ def test_valid_findone_with_filters(
single_investigation_test_data_db,
):
test_response = flask_test_app_db.get(
f"{config.datagateway_api.extension}/investigations/findone?where="
f"{Config.config.datagateway_api.extension}/investigations/findone?where="
'{"title": {"like": "Title for DataGateway API Testing (DB)"}}',
headers=valid_db_credentials_header,
)
Expand All @@ -20,7 +20,7 @@ def test_valid_no_results_findone_with_filters(
self, flask_test_app_db, valid_db_credentials_header,
):
test_response = flask_test_app_db.get(
f"{config.datagateway_api.extension}/investigations/findone?where="
f"{Config.config.datagateway_api.extension}/investigations/findone?where="
'{"title": {"eq": "This filter should cause a404 for testing '
'purposes..."}}',
headers=valid_db_credentials_header,
Expand Down
11 changes: 6 additions & 5 deletions test/datagateway_api/db/endpoints/test_get_by_id_db.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from datagateway_api.src.common.config import config
from datagateway_api.src.common.config import Config


class TestDBGetByID:
Expand All @@ -10,14 +10,14 @@ def test_valid_get_with_id(
):
# Need to identify the ID given to the test data
investigation_data = flask_test_app_db.get(
f"{config.datagateway_api.extension}/investigations?where="
f"{Config.config.datagateway_api.extension}/investigations?where="
'{"title": {"like": "Title for DataGateway API Testing (DB)"}}',
headers=valid_db_credentials_header,
)
test_data_id = investigation_data.json[0]["id"]

test_response = flask_test_app_db.get(
f"{config.datagateway_api.extension}/investigations/{test_data_id}",
f"{Config.config.datagateway_api.extension}/investigations/{test_data_id}",
headers=valid_db_credentials_header,
)

Expand All @@ -27,15 +27,16 @@ def test_invalid_get_with_id(
self, flask_test_app_db, valid_db_credentials_header,
):
final_investigation_result = flask_test_app_db.get(
f"{config.datagateway_api.extension}/investigations/findone?order="
f"{Config.config.datagateway_api.extension}/investigations/findone?order="
'"id DESC"',
headers=valid_db_credentials_header,
)
test_data_id = final_investigation_result.json["id"]

# Adding 100 onto the ID to the most recent result should ensure a 404
test_response = flask_test_app_db.get(
f"{config.datagateway_api.extension}/investigations/{test_data_id + 100}",
f"{Config.config.datagateway_api.extension}/investigations"
f"/{test_data_id + 100}",
headers=valid_db_credentials_header,
)

Expand Down
Loading

0 comments on commit 1cf73e4

Please sign in to comment.