diff --git a/data_safe_haven/config/config.py b/data_safe_haven/config/config.py index 76ff4a4f61..6e1c222dd3 100644 --- a/data_safe_haven/config/config.py +++ b/data_safe_haven/config/config.py @@ -1,7 +1,7 @@ """Configuration file backed by blob storage""" from __future__ import annotations -import pathlib +from pathlib import Path from typing import ClassVar import yaml @@ -282,28 +282,34 @@ def is_complete(self, *, require_sres: bool) -> bool: return False return True - def read_stack(self, name: str, path: pathlib.Path) -> None: - """Add a Pulumi stack file to config""" - with open(path, encoding="utf-8") as f_stack: - pulumi_cfg = f_stack.read() - self.pulumi.stacks[name] = b64encode(pulumi_cfg) + def sre(self, name: str) -> ConfigSectionSRE: + """Return the config entry for this SRE creating it if it does not exist""" + if name not in self.sres.keys(): + highest_index = max(0 + sre.index for sre in self.sres.values()) + self.sres[name].index = highest_index + 1 + return self.sres[name] def remove_sre(self, name: str) -> None: """Remove SRE config section by name""" if name in self.sres.keys(): del self.sres[name] + def add_stack(self, name: str, path: Path) -> None: + """Add a Pulumi stack file to config""" + with open(path, encoding="utf-8") as f_stack: + pulumi_cfg = f_stack.read() + self.pulumi.stacks[name] = b64encode(pulumi_cfg) + def remove_stack(self, name: str) -> None: """Remove Pulumi stack section by name""" if name in self.pulumi.stacks.keys(): del self.pulumi.stacks[name] - def sre(self, name: str) -> ConfigSectionSRE: - """Return the config entry for this SRE creating it if it does not exist""" - if name not in self.sres.keys(): - highest_index = max([0] + [sre.index for sre in self.sres.values()]) - self.sres[name].index = highest_index + 1 - return self.sres[name] + def write_stack(self, name: str, path: Path) -> None: + """Write a Pulumi stack file from config""" + pulumi_cfg = b64decode(self.pulumi.stacks[name]) + with open(path, "w", encoding="utf-8") as f_stack: + f_stack.write(pulumi_cfg) @classmethod def from_yaml(cls, config_yaml: str) -> Config: @@ -347,9 +353,3 @@ def upload(self) -> None: self.context.storage_account_name, self.context.storage_container_name, ) - - def write_stack(self, name: str, path: pathlib.Path) -> None: - """Write a Pulumi stack file from config""" - pulumi_cfg = b64decode(self.pulumi.stacks[name]) - with open(path, "w", encoding="utf-8") as f_stack: - f_stack.write(pulumi_cfg) diff --git a/tests_/config/test_config.py b/tests_/config/test_config.py index 426f325585..452aa24c04 100644 --- a/tests_/config/test_config.py +++ b/tests_/config/test_config.py @@ -286,11 +286,6 @@ def test_constructor_defaults(self, context): (config.azure, config.pulumi, config.shm, config.tags, config.sres) ) - @pytest.mark.parametrize("require_sres", [False, True]) - def test_is_complete_bare(self, context, require_sres): - config = Config(context=context) - assert config.is_complete(require_sres=require_sres) is False - def test_constructor( self, context, azure_config, pulumi_config, shm_config, tags_config ): @@ -303,6 +298,15 @@ def test_constructor( ) assert not config.sres + def test_work_directory(self, config_sres): + config = config_sres + assert config.work_directory == config.context.work_directory + + @pytest.mark.parametrize("require_sres", [False, True]) + def test_is_complete_bare(self, context, require_sres): + config = Config(context=context) + assert config.is_complete(require_sres=require_sres) is False + @pytest.mark.parametrize("require_sres,expected", [(False, True), (True, False)]) def test_is_complete_no_sres(self, config_no_sres, require_sres, expected): assert config_no_sres.is_complete(require_sres=require_sres) is expected @@ -311,13 +315,6 @@ def test_is_complete_no_sres(self, config_no_sres, require_sres, expected): def test_is_complete_sres(self, config_sres, require_sres): assert config_sres.is_complete(require_sres=require_sres) - def test_work_directory(self, config_sres): - config = config_sres - assert config.work_directory == config.context.work_directory - - def test_to_yaml(self, config_sres, config_yaml): - assert config_sres.to_yaml() == config_yaml - def test_from_yaml(self, config_sres, config_yaml): config = Config.from_yaml(config_yaml) assert config == config_sres @@ -325,20 +322,6 @@ def test_from_yaml(self, config_sres, config_yaml): config.sres["sre1"].software_packages, SoftwarePackageCategory ) - def test_upload(self, config_sres, monkeypatch): - def mock_upload_blob( - self, # noqa: ARG001 - blob_data: bytes | str, # noqa: ARG001 - blob_name: str, # noqa: ARG001 - resource_group_name: str, # noqa: ARG001 - storage_account_name: str, # noqa: ARG001 - storage_container_name: str, # noqa: ARG001 - ): - pass - - monkeypatch.setattr(AzureApi, "upload_blob", mock_upload_blob) - config_sres.upload() - def test_from_remote(self, context, config_sres, config_yaml, monkeypatch): def mock_download_blob( self, # noqa: ARG001 @@ -356,3 +339,20 @@ def mock_download_blob( monkeypatch.setattr(AzureApi, "download_blob", mock_download_blob) config = Config.from_remote(context) assert config == config_sres + + def test_to_yaml(self, config_sres, config_yaml): + assert config_sres.to_yaml() == config_yaml + + def test_upload(self, config_sres, monkeypatch): + def mock_upload_blob( + self, # noqa: ARG001 + blob_data: bytes | str, # noqa: ARG001 + blob_name: str, # noqa: ARG001 + resource_group_name: str, # noqa: ARG001 + storage_account_name: str, # noqa: ARG001 + storage_container_name: str, # noqa: ARG001 + ): + pass + + monkeypatch.setattr(AzureApi, "upload_blob", mock_upload_blob) + config_sres.upload()