Skip to content

Commit

Permalink
#210: Replace old config getters with calls to generic
Browse files Browse the repository at this point in the history
- This commit also removes the old getters, as they're no longer being used
  • Loading branch information
MRichards99 committed Apr 20, 2021
1 parent c1a02e5 commit 9ad2f57
Show file tree
Hide file tree
Showing 9 changed files with 80 additions and 131 deletions.
84 changes: 4 additions & 80 deletions datagateway_api/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,6 @@ def get_config_value(self, config_key):
except KeyError:
sys.exit(f"Missing config value: {config_key}")

def get_backend_type(self):
try:
return self.config["backend"]
except KeyError:
sys.exit("Missing config value, backend")

def set_backend_type(self, backend_type):
"""
This setter is used as a way for automated tests to set the backend type. The
Expand All @@ -47,86 +41,16 @@ def set_backend_type(self, backend_type):
"""
self.config["backend"] = backend_type

def get_db_url(self):
try:
return self.config["db_url"]
except KeyError:
sys.exit("Missing config value, db_url")

def is_flask_reloader(self):
try:
return self.config["flask_reloader"]
except KeyError:
sys.exit("Missing config value, flask_reloader")

def get_icat_url(self):
try:
return self.config["icat_url"]
except KeyError:
sys.exit("Missing config value, icat_url")

def get_icat_check_cert(self):
try:
return self.config["icat_check_cert"]
except KeyError:
sys.exit("Missing config value, icat_check_cert")

def get_log_level(self):
try:
return self.config["log_level"]
except KeyError:
sys.exit("Missing config value, log_level")

def get_log_location(self):
try:
return self.config["log_location"]
except KeyError:
sys.exit("Missing config value, log_location")

def is_debug_mode(self):
try:
return self.config["debug_mode"]
except KeyError:
sys.exit("Missing config value, debug_mode")

def is_generate_swagger(self):
try:
return self.config["generate_swagger"]
except KeyError:
sys.exit("Missing config value, generate_swagger")

def get_host(self):
try:
return self.config["host"]
except KeyError:
sys.exit("Missing config value, host")

def get_port(self):
try:
return self.config["port"]
except KeyError:
sys.exit("Missing config value, port")

def get_test_user_credentials(self):
try:
return self.config["test_user_credentials"]
except KeyError:
sys.exit("Missing config value, test_user_credentials")

def get_test_mechanism(self):
try:
return self.config["test_mechanism"]
except KeyError:
sys.exit("Missing config value, test_mechanism")

def get_icat_properties(self):
"""
ICAT properties can be retrieved using Python ICAT's client object, however this
requires the client object to be authenticated which may not always be the case
when requesting these properties, hence a HTTP request is sent as an alternative
"""
properties_url = f"{config.get_icat_url()}/icat/properties"
r = requests.request("GET", properties_url, verify=config.get_icat_check_cert())
properties_url = f"{config.get_config_value('icat_url')}/icat/properties"
r = requests.request(
"GET", properties_url, verify=config.get_config_value("icat_check_cert")
)
icat_properties = r.json()

return icat_properties
Expand Down
3 changes: 2 additions & 1 deletion datagateway_api/common/icat/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ def wrapper_requires_session(*args, **kwargs):

def create_client():
client = icat.client.Client(
config.get_icat_url(), checkCert=config.get_icat_check_cert(),
config.get_config_value("icat_url"),
checkCert=config.get_config_value("icat_check_cert"),
)
return client

Expand Down
6 changes: 3 additions & 3 deletions datagateway_api/common/logger_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from datagateway_api.common.config import config

LOG_FILE_NAME = Path(config.get_log_location())
LOG_FILE_NAME = Path(config.get_config_value("log_location"))
logger_config = {
"version": 1,
"formatters": {
Expand All @@ -14,15 +14,15 @@
},
"handlers": {
"default": {
"level": config.get_log_level(),
"level": config.get_config_value("log_level"),
"formatter": "default",
"class": "logging.handlers.RotatingFileHandler",
"filename": LOG_FILE_NAME,
"maxBytes": 5000000,
"backupCount": 10,
},
},
"root": {"level": config.get_log_level(), "handlers": ["default"]},
"root": {"level": config.get_config_value("log_level"), "handlers": ["default"]},
}


Expand Down
8 changes: 4 additions & 4 deletions datagateway_api/src/api_start_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,10 @@ def create_app_infrastructure(flask_app):
backend_type = flask_app.config["TEST_BACKEND"]
config.set_backend_type(backend_type)
except KeyError:
backend_type = config.get_backend_type()
backend_type = config.get_config_value("backend")

if backend_type == "db":
flask_app.config["SQLALCHEMY_DATABASE_URI"] = config.get_db_url()
flask_app.config["SQLALCHEMY_DATABASE_URI"] = config.get_config_value("db_url")
flask_app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
db.init_app(flask_app)

Expand All @@ -80,7 +80,7 @@ def create_api_endpoints(flask_app, api, spec):
backend_type = flask_app.config["TEST_BACKEND"]
config.set_backend_type(backend_type)
except KeyError:
backend_type = config.get_backend_type()
backend_type = config.get_config_value("backend")

backend = create_backend(backend_type)

Expand Down Expand Up @@ -153,7 +153,7 @@ def create_api_endpoints(flask_app, api, spec):
def openapi_config(spec):
# Reorder paths (e.g. get, patch, post) so openapi.yaml only changes when there's a
# change to the Swagger docs, rather than changing on each startup
if config.is_generate_swagger():
if config.get_config_value("generate_swagger"):
log.debug("Reordering OpenAPI docs to alphabetical order")
for entity_data in spec._paths.values():
for endpoint_name in sorted(entity_data.keys()):
Expand Down
8 changes: 4 additions & 4 deletions datagateway_api/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@

if __name__ == "__main__":
app.run(
host=config.get_host(),
port=config.get_port(),
debug=config.is_debug_mode(),
use_reloader=config.is_flask_reloader(),
host=config.get_config_value("host"),
port=config.get_config_value("port"),
debug=config.get_config_value("debug_mode"),
use_reloader=config.get_config_value("flask_reloader"),
)
10 changes: 8 additions & 2 deletions test/icat/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,14 @@

@pytest.fixture(scope="package")
def icat_client():
client = Client(config.get_icat_url(), checkCert=config.get_icat_check_cert())
client.login(config.get_test_mechanism(), config.get_test_user_credentials())
client = Client(
config.get_config_value("icat_url"),
checkCert=config.get_config_value("icat_check_cert"),
)
client.login(
config.get_config_value("test_mechanism"),
config.get_config_value("test_user_credentials"),
)
return client


Expand Down
35 changes: 25 additions & 10 deletions test/icat/test_session_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@ def test_get_valid_session_details(

# Check username is correct
assert (
session_details.json["username"] == f"{config.get_test_mechanism()}/"
f"{config.get_test_user_credentials()['username']}"
session_details.json["username"]
== f"{config.get_config_value('test_mechanism')}/"
f"{config.get_config_value('test_user_credentials')['username']}"
)

# Check session ID matches the header from the request
Expand Down Expand Up @@ -72,16 +73,24 @@ def test_refresh_session(self, valid_icat_credentials_header, flask_test_app_ica
[
pytest.param(
{
"username": config.get_test_user_credentials()["username"],
"password": config.get_test_user_credentials()["password"],
"mechanism": config.get_test_mechanism(),
"username": config.get_config_value("test_user_credentials")[
"username"
],
"password": config.get_config_value("test_user_credentials")[
"password"
],
"mechanism": config.get_config_value("test_mechanism"),
},
id="Normal request body",
),
pytest.param(
{
"username": config.get_test_user_credentials()["username"],
"password": config.get_test_user_credentials()["password"],
"username": config.get_config_value("test_user_credentials")[
"username"
],
"password": config.get_config_value("test_user_credentials")[
"password"
],
},
id="Missing mechanism in request body",
),
Expand Down Expand Up @@ -110,7 +119,7 @@ def test_valid_login(
{
"username": "Invalid Username",
"password": "InvalidPassword",
"mechanism": config.get_test_mechanism(),
"mechanism": config.get_config_value("test_mechanism"),
},
403,
id="Invalid credentials",
Expand All @@ -126,8 +135,14 @@ def test_invalid_login(
assert login_response.status_code == expected_response_code

def test_valid_logout(self, flask_test_app_icat):
client = Client(config.get_icat_url(), checkCert=config.get_icat_check_cert())
client.login(config.get_test_mechanism(), config.get_test_user_credentials())
client = Client(
config.get_config_value("icat_url"),
checkCert=config.get_config_value("icat_check_cert"),
)
client.login(
config.get_config_value("test_mechanism"),
config.get_config_value("test_user_credentials"),
)
creds_header = {"Authorization": f"Bearer {client.sessionId}"}

logout_response = flask_test_app_icat.delete("/sessions", headers=creds_header)
Expand Down
Loading

0 comments on commit 9ad2f57

Please sign in to comment.