Skip to content

Commit

Permalink
Improve design of config sections
Browse files Browse the repository at this point in the history
  • Loading branch information
JimMadge committed Nov 21, 2023
1 parent a741542 commit acb01b8
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 67 deletions.
79 changes: 60 additions & 19 deletions data_safe_haven/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
from __future__ import annotations

import pathlib
from typing import ClassVar

import yaml
from pydantic import BaseModel, Field, ValidationError, computed_field
from pydantic import BaseModel, Field, ValidationError, computed_field, field_validator
from yaml import YAMLError

from data_safe_haven import __version__
Expand Down Expand Up @@ -39,6 +40,17 @@ class ConfigSectionAzure(BaseModel, validate_assignment=True):
subscription_id: Guid
tenant_id: Guid

@classmethod
def from_context(
cls, context: Context, subscription_id: Guid, tenant_id: Guid
) -> ConfigSectionAzure:
return ConfigSectionAzure(
admin_group_id=context.admin_group_id,
location=context.location,
subscription_id=subscription_id,
tenant_id=tenant_id,
)


class ConfigSectionPulumi(BaseModel, validate_assignment=True):
encryption_key_name: str = "pulumi-encryption-key"
Expand All @@ -55,14 +67,33 @@ class ConfigSectionSHM(BaseModel, validate_assignment=True):
name: str
timezone: TimeZone

@classmethod
def from_context(
cls,
context: Context,
aad_tenant_id: Guid,
admin_email_address: EmailAdress,
admin_ip_addresses: list[IpAddress],
fqdn: str,
timezone: TimeZone,
) -> ConfigSectionSHM:
return ConfigSectionSHM(
aad_tenant_id=aad_tenant_id,
admin_email_address=admin_email_address,
admin_ip_addresses=admin_ip_addresses,
fqdn=fqdn,
name=context.shm_name,
timezone=timezone,
)

def update(
self,
*,
aad_tenant_id: str | None = None,
admin_email_address: str | None = None,
admin_ip_addresses: list[str] | None = None,
fqdn: str | None = None,
timezone: str | None = None,
timezone: TimeZone | None = None,
) -> None:
"""Update SHM settings
Expand Down Expand Up @@ -132,30 +163,38 @@ def update(


class ConfigSectionSRE(BaseModel, validate_assignment=True):
databases: list[DatabaseSystem]
data_provider_ip_addresses: list[IpAddress]
databases: list[DatabaseSystem] = Field(default_factory=list[DatabaseSystem])
data_provider_ip_addresses: list[IpAddress] = Field(default_factory=list[IpAddress])
index: int = Field(ge=0)
remote_desktop: ConfigSubsectionRemoteDesktopOpts
workspace_skus: list[AzureVmSku]
research_user_ip_addresses: list[IpAddress]
remote_desktop: ConfigSubsectionRemoteDesktopOpts = Field(
default_factory=ConfigSubsectionRemoteDesktopOpts
)
workspace_skus: list[AzureVmSku] = Field(default_factory=list[AzureVmSku])
research_user_ip_addresses: list[IpAddress] = Field(default_factory=list[IpAddress])
software_packages: SoftwarePackageCategory = SoftwarePackageCategory.NONE

@field_validator("databases")
@classmethod
def all_databases_must_be_unique(
cls, v: list[DatabaseSystem]
) -> list[DatabaseSystem]:
if len(v) != len(set(v)):
msg = "all databases must be unique"
raise ValueError(msg)
return v

def update(
self,
*,
allow_copy: bool | None = None,
allow_paste: bool | None = None,
data_provider_ip_addresses: list[str] | None = None,
data_provider_ip_addresses: list[IpAddress] | None = None,
databases: list[DatabaseSystem] | None = None,
workspace_skus: list[str] | None = None,
workspace_skus: list[AzureVmSku] | None = None,
software_packages: SoftwarePackageCategory | None = None,
user_ip_addresses: list[str] | None = None,
user_ip_addresses: list[IpAddress] | None = None,
) -> None:
"""Update SRE settings
Args:
allow_copy: Allow/deny copying text out of the SRE
allow_paste: Allow/deny pasting text into the SRE
databases: List of database systems to deploy
data_provider_ip_addresses: List of IP addresses belonging to data providers
workspace_skus: List of VM SKUs for workspaces
Expand All @@ -177,8 +216,6 @@ def update(
logger.info(
f"[bold]Databases available to users[/] will be [green]{[database.value for database in self.databases]}[/]."
)
# Pass allow_copy and allow_paste to remote desktop
self.remote_desktop.update(allow_copy=allow_copy, allow_paste=allow_paste)
# Set research desktop SKUs
if workspace_skus:
self.workspace_skus = workspace_skus
Expand All @@ -199,9 +236,13 @@ def update(

class ConfigSectionTags(BaseModel, validate_assignment=True):
deployment: str
deployed_by: str = "Python"
project: str = "Data Safe Haven"
version: str = __version__
deployed_by: ClassVar[str] = "Python"
project: ClassVar[str] = "Data Safe Haven"
version: ClassVar[str] = __version__

@classmethod
def from_context(cls, context: Context) -> ConfigSectionTags:
return ConfigSectionTags(deployment=context.name)


class Config(BaseModel, validate_assignment=True):
Expand Down
12 changes: 8 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,12 @@ dependencies = [
typing = "mypy {args:data_safe_haven}"

style = [
"ruff {args:data_safe_haven}",
"black --check --diff {args:data_safe_haven}",
"ruff {args:data_safe_haven tests_}",
"black --check --diff {args:data_safe_haven tests_}",
]
fmt = [
"black {args:data_safe_haven}",
"ruff --fix {args:data_safe_haven}",
"black {args:data_safe_haven tests_}",
"ruff --fix {args:data_safe_haven tests_}",
"style",
]
all = [
Expand Down Expand Up @@ -148,6 +148,10 @@ known-first-party = ["data_safe_haven"]
[tool.ruff.flake8-tidy-imports]
ban-relative-imports = "parents"

[tool.ruff.per-file-ignores]
# Tests can use magic values, assertions, and relative imports
"tests_/**/*" = ["PLR2004", "S101", "TID252"]

[tool.mypy]
disallow_subclassing_any = false # allow subclassing of types from third-party libraries
files = "data_safe_haven" # run mypy over this directory
Expand Down
28 changes: 14 additions & 14 deletions tests_/commands/test_context.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from pytest import fixture
from typer.testing import CliRunner

from data_safe_haven.commands.context import context_command_group
from data_safe_haven.config import Config
from data_safe_haven.context import Context

from pytest import fixture
from typer.testing import CliRunner

context_settings = """\
selected: acme_deployment
contexts:
Expand Down Expand Up @@ -88,7 +88,7 @@ def test_add(self, runner):
"uksouth",
"--subscription",
"Data Safe Haven (Example)",
]
],
)
assert result.exit_code == 0
result = runner.invoke(context_command_group, ["switch", "example"])
Expand All @@ -108,7 +108,7 @@ def test_add_duplicate(self, runner):
"uksouth",
"--subscription",
"Data Safe Haven (Acme)",
]
],
)
assert result.exit_code == 1
# Unable to check error as this is written outside of any Typer
Expand All @@ -128,7 +128,7 @@ def test_add_invalid_uuid(self, runner):
"uksouth",
"--subscription",
"Data Safe Haven (Example)",
]
],
)
assert result.exit_code == 2
# This works because the context_command_group Typer writes this error
Expand All @@ -142,7 +142,7 @@ def test_add_missing_ags(self, runner):
"example",
"--name",
"Example",
]
],
)
assert result.exit_code == 2
assert "Missing option" in result.stderr
Expand All @@ -162,7 +162,7 @@ def test_add_bootstrap(self, tmp_contexts, runner):
"uksouth",
"--subscription",
"Data Safe Haven (Acme)",
]
],
)
assert result.exit_code == 0
assert (tmp_contexts / "contexts.yaml").exists()
Expand Down Expand Up @@ -201,11 +201,11 @@ def test_remove_invalid(self, runner):

class TestCreate:
def test_create(self, runner, monkeypatch):
def mock_create(self):
print("mock create")
def mock_create():
print("mock create") # noqa: T201

def mock_upload(self):
print("mock upload")
def mock_upload():
print("mock upload") # noqa: T201

monkeypatch.setattr(Context, "create", mock_create)
monkeypatch.setattr(Config, "upload", mock_upload)
Expand All @@ -218,8 +218,8 @@ def mock_upload(self):

class TestTeardown:
def test_teardown(self, runner, monkeypatch):
def mock_teardown(self):
print("mock teardown")
def mock_teardown():
print("mock teardown") # noqa: T201

monkeypatch.setattr(Context, "teardown", mock_teardown)

Expand Down
52 changes: 22 additions & 30 deletions tests_/config/test_context_settings.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,22 @@
from data_safe_haven.config.context_settings import Context, ContextSettings
from data_safe_haven.exceptions import DataSafeHavenConfigError, DataSafeHavenParameterError

import pytest
import yaml
from pydantic import ValidationError
from pytest import fixture


@fixture
def context_dict():
return {
"admin_group_id": "d5c5c439-1115-4cb6-ab50-b8e547b6c8dd",
"location": "uksouth",
"name": "Acme Deployment",
"subscription_name": "Data Safe Haven (Acme)"
}


@fixture
def context(context_dict):
return Context(**context_dict)
from data_safe_haven.config.context_settings import Context, ContextSettings
from data_safe_haven.exceptions import (
DataSafeHavenConfigError,
DataSafeHavenParameterError,
)


class TestContext:
def test_constructor(self, context_dict):
context = Context(**context_dict)
assert isinstance(context, Context)
assert all([
assert all(
getattr(context, item) == context_dict[item] for item in context_dict.keys()
])
)
assert context.storage_container_name == "config"

def test_invalid_guid(self, context_dict):
Expand All @@ -44,7 +32,7 @@ def test_invalid_location(self, context_dict):
assert "Value error, Expected valid Azure location" in exc

def test_invalid_subscription_name(self, context_dict):
context_dict["subscription_name"] = "very "*12 + "long name"
context_dict["subscription_name"] = "very " * 12 + "long name"
with pytest.raises(ValidationError) as exc:
Context(**context_dict)
assert "String should have at most 64 characters" in exc
Expand All @@ -65,7 +53,7 @@ def test_storage_account_name(self, context):
assert context.storage_account_name == "shmacmedeploymentcontext"

def test_long_storage_account_name(self, context_dict):
context_dict["name"] = "very "*5 + "long name"
context_dict["name"] = "very " * 5 + "long name"
context = Context(**context_dict)
assert context.storage_account_name == "shmveryveryveryvecontext"

Expand Down Expand Up @@ -119,7 +107,9 @@ def test_missing_selected(self, context_yaml):
assert "Field required" in exc

def test_invalid_selected_input(self, context_yaml):
context_yaml = context_yaml.replace("selected: acme_deployment", "selected: invalid")
context_yaml = context_yaml.replace(
"selected: acme_deployment", "selected: invalid"
)

with pytest.raises(DataSafeHavenParameterError) as exc:
ContextSettings.from_yaml(context_yaml)
Expand Down Expand Up @@ -153,24 +143,26 @@ def test_invalid_selected(self, context_settings):
def test_context(self, context_yaml, context_settings):
yaml_dict = yaml.safe_load(context_yaml)
assert isinstance(context_settings.context, Context)
assert all([
getattr(context_settings.context, item) == yaml_dict["contexts"]["acme_deployment"][item]
assert all(
getattr(context_settings.context, item)
== yaml_dict["contexts"]["acme_deployment"][item]
for item in yaml_dict["contexts"]["acme_deployment"].keys()
])
)

def test_set_context(self, context_yaml, context_settings):
yaml_dict = yaml.safe_load(context_yaml)
context_settings.selected = "gems"
assert isinstance(context_settings.context, Context)
assert all([
getattr(context_settings.context, item) == yaml_dict["contexts"]["gems"][item]
assert all(
getattr(context_settings.context, item)
== yaml_dict["contexts"]["gems"][item]
for item in yaml_dict["contexts"]["gems"].keys()
])
)

def test_available(self, context_settings):
available = context_settings.available
assert isinstance(available, list)
assert all([isinstance(item, str) for item in available])
assert all(isinstance(item, str) for item in available)
assert available == ["acme_deployment", "gems"]

def test_update(self, context_settings):
Expand Down Expand Up @@ -238,7 +230,7 @@ def test_write(self, tmp_path, context_yaml):
settings.selected = "gems"
settings.update(name="replaced")
settings.write(config_file_path)
with open(config_file_path, "r") as f:
with open(config_file_path) as f:
context_dict = yaml.safe_load(f)
assert context_dict["selected"] == "gems"
assert context_dict["contexts"]["gems"]["name"] == "replaced"

0 comments on commit acb01b8

Please sign in to comment.