diff --git a/data_safe_haven/config/config.py b/data_safe_haven/config/config.py index 552395a498..3e68f9205d 100644 --- a/data_safe_haven/config/config.py +++ b/data_safe_haven/config/config.py @@ -5,7 +5,7 @@ from typing import ClassVar import yaml -from pydantic import BaseModel, Field, ValidationError, computed_field, field_validator +from pydantic import BaseModel, Field, ValidationError, field_validator from yaml import YAMLError from data_safe_haven import __version__ @@ -251,12 +251,22 @@ class Config(BaseModel, validate_assignment=True): pulumi: ConfigSectionPulumi | None = None shm: ConfigSectionSHM | None = None tags: ConfigSectionTags | None = None - sres: dict[str, ConfigSectionSRE] | None = None + sres: dict[str, ConfigSectionSRE] = Field( + default_factory=dict[str, ConfigSectionSRE] + ) - @computed_field + @property def work_directory(self) -> str: return self.context.work_directory + def is_complete(self, *, require_sres: bool) -> bool: + if require_sres: + if not self.sres: + return False + if not all((self.azure, self.pulumi, self.shm, self.tags)): + return False + return True + def read_stack(self, name: str, path: pathlib.Path) -> None: """Add a Pulumi stack file to config""" with open(path, encoding="utf-8") as f_stack: diff --git a/tests_/config/test_config.py b/tests_/config/test_config.py index 2b55b8b827..5542735cc1 100644 --- a/tests_/config/test_config.py +++ b/tests_/config/test_config.py @@ -3,6 +3,7 @@ from pytest import fixture from data_safe_haven.config.config import ( + Config, ConfigSectionAzure, ConfigSectionPulumi, ConfigSectionSHM, @@ -14,6 +15,15 @@ from data_safe_haven.version import __version__ +@fixture +def azure_config(context): + return ConfigSectionAzure.from_context( + context=context, + subscription_id="d5c5c439-1115-4cb6-ab50-b8e547b6c8dd", + tenant_id="d5c5c439-1115-4cb6-ab50-b8e547b6c8dd", + ) + + class TestConfigSectionAzure: def test_constructor(self): ConfigSectionAzure( @@ -32,6 +42,11 @@ def test_from_context(self, context): assert azure_config.location == context.location +@fixture +def pulumi_config(): + return ConfigSectionPulumi(encryption_key_version="lorem") + + class TestConfigSectionPulumi: def test_constructor_defaults(self): pulumi_config = ConfigSectionPulumi(encryption_key_version="lorem") @@ -41,13 +56,13 @@ def test_constructor_defaults(self): @fixture -def shm_config(): - return ConfigSectionSHM( +def shm_config(context): + return ConfigSectionSHM.from_context( + context=context, aad_tenant_id="d5c5c439-1115-4cb6-ab50-b8e547b6c8dd", admin_email_address="admin@example.com", admin_ip_addresses=["0.0.0.0"], # noqa: S104 fqdn="shm.acme.com", - name="ACME SHM", timezone="UTC", ) @@ -98,16 +113,16 @@ def test_constructor(self): def test_constructor_defaults(self): remote_desktop_config = ConfigSubsectionRemoteDesktopOpts() assert not all( - [remote_desktop_config.allow_copy, remote_desktop_config.allow_paste] + (remote_desktop_config.allow_copy, remote_desktop_config.allow_paste) ) def test_update(self, remote_desktop_config): assert not all( - [remote_desktop_config.allow_copy, remote_desktop_config.allow_paste] + (remote_desktop_config.allow_copy, remote_desktop_config.allow_paste) ) remote_desktop_config.update(allow_copy=True, allow_paste=True) assert all( - [remote_desktop_config.allow_copy, remote_desktop_config.allow_paste] + (remote_desktop_config.allow_copy, remote_desktop_config.allow_paste) ) @@ -162,6 +177,11 @@ def test_update(self): assert sre_config.software_packages == SoftwarePackageCategory.ANY +@fixture +def tags_config(context): + return ConfigSectionTags.from_context(context) + + class TestConfigSectionTags: def test_constructor(self): tags_config = ConfigSectionTags(deployment="Test Deployment") @@ -176,3 +196,63 @@ def test_from_context(self, context): assert tags_config.deployed_by == "Python" assert tags_config.project == "Data Safe Haven" assert tags_config.version == __version__ + + +@fixture +def config_no_sres(context, azure_config, pulumi_config, shm_config, tags_config): + return Config( + context=context, + azure=azure_config, + pulumi=pulumi_config, + shm=shm_config, + tags=tags_config + ) + + +@fixture +def config_sres(context, azure_config, pulumi_config, shm_config, tags_config): + sre_config_1 = ConfigSectionSRE(index=0) + sre_config_2 = ConfigSectionSRE(index=1) + return Config( + context=context, + azure=azure_config, + pulumi=pulumi_config, + shm=shm_config, + sres={ + "sre1": sre_config_1, + "sre2": sre_config_2, + }, + tags=tags_config + ) + + +class TestConfig: + def test_constructor_defaults(self, context): + config = Config(context=context) + assert config.context == context + assert not any( + (config.azure, config.pulumi, config.shm, config.tags, config.sres) + ) + + @pytest.mark.parametrize("require_sres", [False, True]) + def test_is_complete_bare(self, context, require_sres): + config = Config(context=context) + assert config.is_complete(require_sres=require_sres) is False + + def test_constructor(self, context, azure_config, pulumi_config, shm_config, tags_config): + config = Config( + context=context, + azure=azure_config, + pulumi=pulumi_config, + shm=shm_config, + tags=tags_config + ) + assert not config.sres + + @pytest.mark.parametrize("require_sres,expected", [(False, True), (True, False)]) + def test_is_complete_no_sres(self, config_no_sres, require_sres, expected): + assert config_no_sres.is_complete(require_sres=require_sres) is expected + + @pytest.mark.parametrize("require_sres", [False, True]) + def test_is_complete_sres(self, config_sres, require_sres): + assert config_sres.is_complete(require_sres=require_sres)