diff --git a/.github/workflows/lint_code.yaml b/.github/workflows/lint_code.yaml index 56ca0d6b7a..9a24fa80d3 100644 --- a/.github/workflows/lint_code.yaml +++ b/.github/workflows/lint_code.yaml @@ -59,8 +59,10 @@ jobs: python-version: 3.11 - name: Install hatch run: pip install hatch - - name: Print Ruff version - run: hatch run lint:ruff --version + - name: Print package versions + run: | + hatch run lint:ruff --version + hatch run lint:black --version - name: Lint Python run: hatch run lint:all diff --git a/data_safe_haven/README.md b/data_safe_haven/README.md index 07c414dcad..697d3a823f 100644 --- a/data_safe_haven/README.md +++ b/data_safe_haven/README.md @@ -29,7 +29,6 @@ Install the following requirements before starting > dsh deploy shm ``` -You will be prompted for various settings. Run `dsh deploy shm -h` to see the necessary command line flags and provide them as arguments. - Add one or more users from a CSV file with columns named (`GivenName`, `Surname`, `Phone`, `Email`, `CountryCode`). @@ -40,13 +39,20 @@ Run `dsh deploy shm -h` to see the necessary command line flags and provide them > dsh admin add-users ``` -- Next deploy the infrastructure for one or more Secure Research Environments (SREs) [approx 30 minutes]: +- Create the configuration for one or more Secure Research Environments (SREs). + +```console +> dsh config show --file config.yaml +> vim config.yaml +> dsh config upload config.yaml +``` + +- Next deploy the infrastructure [approx 30 minutes]: ```console > dsh deploy sre ``` -You will be prompted for various settings. Run `dsh deploy sre -h` to see the necessary command line flags and provide them as arguments. - Next add one or more existing users to your SRE diff --git a/data_safe_haven/administration/users/user_handler.py b/data_safe_haven/administration/users/user_handler.py index 561b799d36..7e8497bc6f 100644 --- a/data_safe_haven/administration/users/user_handler.py +++ b/data_safe_haven/administration/users/user_handler.py @@ -71,7 +71,7 @@ def get_usernames(self) -> dict[str, list[str]]: usernames = {} usernames["Azure AD"] = self.get_usernames_azure_ad() usernames["Domain controller"] = self.get_usernames_domain_controller() - for sre_name in self.config.sres.keys(): + for sre_name in self.config.sre_names: usernames[f"SRE {sre_name}"] = self.get_usernames_guacamole(sre_name) return usernames diff --git a/data_safe_haven/commands/config.py b/data_safe_haven/commands/config.py index 79c6bc3a79..1accbd75f1 100644 --- a/data_safe_haven/commands/config.py +++ b/data_safe_haven/commands/config.py @@ -41,8 +41,17 @@ def upload( @config_command_group.command() -def show() -> None: +def show( + file: Annotated[ + Optional[Path], # noqa: UP007 + typer.Option(help="File path to write configuration template to."), + ] = None +) -> None: """Print the configuration for the selected Data Safe Haven context""" context = ContextSettings.from_file().assert_context() config = Config.from_remote(context) - print(config.to_yaml()) + if file: + with open(file, "w") as outfile: + outfile.write(config.to_yaml()) + else: + print(config.to_yaml()) diff --git a/data_safe_haven/commands/deploy.py b/data_safe_haven/commands/deploy.py index fb769d7eb9..35fe858a96 100644 --- a/data_safe_haven/commands/deploy.py +++ b/data_safe_haven/commands/deploy.py @@ -10,6 +10,7 @@ from data_safe_haven.functions import bcrypt_salt, password from data_safe_haven.infrastructure import SHMStackManager, SREStackManager from data_safe_haven.provisioning import SHMProvisioningManager, SREProvisioningManager +from data_safe_haven.utility import LoggingSingleton deploy_command_group = typer.Typer() @@ -99,12 +100,17 @@ def sre( ] = None, ) -> None: """Deploy a Secure Research Environment""" + logger = LoggingSingleton() context = ContextSettings.from_file().assert_context() config = Config.from_remote(context) - sre_name = config.sanitise_sre_name(name) try: + # Exit if SRE name is not recognised + if sre_name not in config.sre_names: + logger.fatal(f"Could not find configuration details for SRE '{sre_name}'.") + raise typer.Exit(1) + # Load GraphAPI as this may require user-interaction that is not possible as # part of a Pulumi declarative command graph_api = GraphApi( diff --git a/data_safe_haven/config/config.py b/data_safe_haven/config/config.py index 4b7d36353a..d25eace30e 100644 --- a/data_safe_haven/config/config.py +++ b/data_safe_haven/config/config.py @@ -10,15 +10,12 @@ from pydantic import ( BaseModel, Field, - FieldSerializationInfo, ValidationError, - field_serializer, field_validator, ) from yaml import YAMLError from data_safe_haven import __version__ -from data_safe_haven.config.context_settings import Context from data_safe_haven.exceptions import ( DataSafeHavenConfigError, DataSafeHavenParameterError, @@ -44,6 +41,8 @@ TimeZone, ) +from .context_settings import Context + class ConfigSectionAzure(BaseModel, validate_assignment=True): admin_group_id: Guid = Field(..., exclude=True) @@ -155,7 +154,7 @@ class ConfigSectionSRE(BaseModel, validate_assignment=True): data_provider_ip_addresses: list[IpAddress] = Field( ..., default_factory=list[IpAddress] ) - index: int = Field(..., ge=0) + index: int = Field(..., ge=1, le=256) remote_desktop: ConfigSubsectionRemoteDesktopOpts = Field( ..., default_factory=ConfigSubsectionRemoteDesktopOpts ) @@ -168,20 +167,13 @@ class ConfigSectionSRE(BaseModel, validate_assignment=True): @field_validator("databases") @classmethod def all_databases_must_be_unique( - cls, v: list[DatabaseSystem] + cls, v: list[DatabaseSystem | str] ) -> list[DatabaseSystem]: - if len(v) != len(set(v)): + v_ = [DatabaseSystem(d) for d in v] + if len(v_) != len(set(v_)): msg = "all databases must be unique" raise ValueError(msg) - return v - - @field_serializer("software_packages") - def software_packages_serializer( - self, - packages: SoftwarePackageCategory, - info: FieldSerializationInfo, # noqa: ARG002 - ) -> str: - return packages.value + return v_ def update( self, @@ -260,6 +252,17 @@ def __init__(self, context: Context, **kwargs: dict[Any, Any]): tags = ConfigSectionTags(context) super().__init__(context=context, tags=tags, **kwargs) + @field_validator("sres") + @classmethod + def all_sre_indices_must_be_unique( + cls, v: dict[str, ConfigSectionSRE] + ) -> dict[str, ConfigSectionSRE]: + indices = [s.index for s in v.values()] + if len(indices) != len(set(indices)): + msg = "all SRE indices must be unique" + raise ValueError(msg) + return v + @property def work_directory(self) -> Path: return self.context.work_directory @@ -276,9 +279,15 @@ def pulumi_encryption_key(self) -> KeyVaultKey: @property def pulumi_encryption_key_version(self) -> str: + """ID for the Pulumi encryption key""" key_id: str = self.pulumi_encryption_key.id return key_id.split("/")[-1] + @property + def sre_names(self) -> list[str]: + """Names of all SREs""" + return list(self.sres.keys()) + def is_complete(self, *, require_sres: bool) -> bool: if require_sres: if not self.sres: @@ -293,11 +302,16 @@ def sanitise_sre_name(name: str) -> str: def sre(self, name: str) -> ConfigSectionSRE: """Return the config entry for this SRE, raising an exception if it does not exist""" - if name not in self.sres.keys(): + if name not in self.sre_names: msg = f"SRE {name} does not exist" raise DataSafeHavenConfigError(msg) return self.sres[name] + def remove_sre(self, name: str) -> None: + """Remove SRE config section by name""" + if name in self.sre_names: + del self.sres[name] + def add_stack(self, name: str, path: Path) -> None: """Add a Pulumi stack file to config""" if self.pulumi: @@ -336,6 +350,19 @@ def template(cls, context: Context) -> Config: fqdn="TRE domain name", timezone="Timezone", ), + sres={ + "example": ConfigSectionSRE.model_construct( + databases=["List of database systems to enable"], + data_provider_ip_addresses=["Data provider IP addresses"], + remote_desktop=ConfigSubsectionRemoteDesktopOpts.model_construct( + allow_copy="Whether to allow copying text out of the environment", + allow_paste="Whether to allow pasting text into the environment", + ), + workspace_skus=["Azure VM SKUs"], + research_user_ip_addresses=["Research user IP addresses"], + software_packages=SoftwarePackageCategory.ANY, + ) + }, ) @classmethod @@ -356,7 +383,7 @@ def from_yaml(cls, context: Context, config_yaml: str) -> Config: config_dict[section]["context"] = context try: - return Config.model_validate(config_dict) + return Config.model_validate(config_dict, strict=True) except ValidationError as exc: msg = f"Could not load configuration.\n{exc}" raise DataSafeHavenParameterError(msg) from exc @@ -373,7 +400,7 @@ def from_remote(cls, context: Context) -> Config: return Config.from_yaml(context, config_yaml) def to_yaml(self) -> str: - return yaml.dump(self.model_dump(), indent=2) + return yaml.dump(self.model_dump(mode="json"), indent=2) def upload(self) -> None: """Upload config to Azure storage""" diff --git a/data_safe_haven/external/api/azure_cli.py b/data_safe_haven/external/api/azure_cli.py index 1f9cf3976d..e32a597911 100644 --- a/data_safe_haven/external/api/azure_cli.py +++ b/data_safe_haven/external/api/azure_cli.py @@ -68,10 +68,11 @@ def confirm(self) -> None: return None account = self.account - self.logger.info( - f"name: {account.name} (id: {account.id_}\ntenant: {account.tenant_id})" - ) - if not typer.confirm("Is this the Azure account you expect?\n"): + self.logger.info(f"Azure user: {account.name} ({account.id_})") + self.logger.info(f"Azure tenant ID: {account.tenant_id})") + if not self.logger.confirm( + "Is this the Azure account you expect?", default_to_yes=False + ): self.logger.error( "Please use `az login` to connect to the correct Azure CLI account" ) diff --git a/data_safe_haven/infrastructure/stacks/declarative_sre.py b/data_safe_haven/infrastructure/stacks/declarative_sre.py index b25edec917..7e3141616e 100644 --- a/data_safe_haven/infrastructure/stacks/declarative_sre.py +++ b/data_safe_haven/infrastructure/stacks/declarative_sre.py @@ -90,7 +90,7 @@ def run(self) -> None: shm_networking_resource_group_name=self.pulumi_opts.require( "shm-networking-resource_group_name" ), - sre_index=self.cfg.sres[self.sre_name].index, + sre_index=self.cfg.sre(self.sre_name).index, ), tags=self.cfg.tags.model_dump(), ) @@ -124,11 +124,11 @@ def run(self) -> None: "shm-networking-virtual_network_name" ), shm_zone_name=self.cfg.shm.fqdn, - sre_index=self.cfg.sres[self.sre_name].index, + sre_index=self.cfg.sre(self.sre_name).index, sre_name=self.sre_name, - user_public_ip_ranges=self.cfg.sres[ + user_public_ip_ranges=self.cfg.sre( self.sre_name - ].research_user_ip_addresses, + ).research_user_ip_addresses, ), tags=self.cfg.tags.model_dump(), ) @@ -148,7 +148,7 @@ def run(self) -> None: resource_group_name=self.pulumi_opts.require( "shm-monitoring-resource_group_name" ), - sre_index=self.cfg.sres[self.sre_name].index, + sre_index=self.cfg.sre(self.sre_name).index, timezone=self.cfg.shm.timezone, ), tags=self.cfg.tags.model_dump(), @@ -162,9 +162,9 @@ def run(self) -> None: admin_email_address=self.cfg.shm.admin_email_address, admin_group_id=self.cfg.azure.admin_group_id, admin_ip_addresses=self.cfg.shm.admin_ip_addresses, - data_provider_ip_addresses=self.cfg.sres[ + data_provider_ip_addresses=self.cfg.sre( self.sre_name - ].data_provider_ip_addresses, + ).data_provider_ip_addresses, dns_record=networking.shm_ns_record, dns_server_admin_password=dns.password_admin, location=self.cfg.azure.location, @@ -204,8 +204,8 @@ def run(self) -> None: aad_application_fqdn=networking.sre_fqdn, aad_auth_token=self.graph_api_token, aad_tenant_id=self.cfg.shm.aad_tenant_id, - allow_copy=self.cfg.sres[self.sre_name].remote_desktop.allow_copy, - allow_paste=self.cfg.sres[self.sre_name].remote_desktop.allow_paste, + allow_copy=self.cfg.sre(self.sre_name).remote_desktop.allow_copy, + allow_paste=self.cfg.sre(self.sre_name).remote_desktop.allow_paste, database_password=data.password_user_database_admin, dns_server_ip=dns.ip_address, ldap_bind_dn=ldap_bind_dn, @@ -260,7 +260,7 @@ def run(self) -> None: subscription_name=self.cfg.context.subscription_name, virtual_network_resource_group=networking.resource_group, virtual_network=networking.virtual_network, - vm_details=list(enumerate(self.cfg.sres[self.sre_name].workspace_skus)), + vm_details=list(enumerate(self.cfg.sre(self.sre_name).workspace_skus)), ), tags=self.cfg.tags.model_dump(), ) @@ -271,7 +271,7 @@ def run(self) -> None: self.stack_name, SREUserServicesProps( database_service_admin_password=data.password_database_service_admin, - databases=self.cfg.sres[self.sre_name].databases, + databases=self.cfg.sre(self.sre_name).databases, dns_resource_group_name=dns.resource_group.name, dns_server_ip=dns.ip_address, domain_netbios_name=self.pulumi_opts.require( @@ -288,7 +288,7 @@ def run(self) -> None: location=self.cfg.azure.location, networking_resource_group_name=networking.resource_group.name, nexus_admin_password=data.password_nexus_admin, - software_packages=self.cfg.sres[self.sre_name].software_packages, + software_packages=self.cfg.sre(self.sre_name).software_packages, sre_fqdn=networking.sre_fqdn, sre_private_dns_zone_id=networking.sre_private_dns_zone_id, storage_account_key=data.storage_account_data_configuration_key, diff --git a/tests_/commands/test_config.py b/tests_/commands/test_config.py index a73fb4709d..e27cff164d 100644 --- a/tests_/commands/test_config.py +++ b/tests_/commands/test_config.py @@ -6,10 +6,11 @@ def test_template(self, runner): result = runner.invoke(config_command_group, ["template"]) assert result.exit_code == 0 assert "subscription_id: Azure subscription ID" in result.stdout - assert "sres: {}" in result.stdout + assert "shm:" in result.stdout + assert "sres:" in result.stdout def test_template_file(self, runner, tmp_path): - template_file = (tmp_path / "template.yaml").absolute() + template_file = (tmp_path / "template_create.yaml").absolute() result = runner.invoke( config_command_group, ["template", "--file", str(template_file)] ) @@ -17,7 +18,8 @@ def test_template_file(self, runner, tmp_path): with open(template_file) as f: template_text = f.read() assert "subscription_id: Azure subscription ID" in template_text - assert "sres: {}" in template_text + assert "shm:" in template_text + assert "sres:" in template_text class TestUpload: @@ -41,3 +43,15 @@ def test_show(self, runner, config_yaml, mock_download_blob): # noqa: ARG002 result = runner.invoke(config_command_group, ["show"]) assert result.exit_code == 0 assert config_yaml in result.stdout + + def test_show_file( + self, runner, config_yaml, mock_download_blob, tmp_path # noqa: ARG002 + ): + template_file = (tmp_path / "template_show.yaml").absolute() + result = runner.invoke( + config_command_group, ["show", "--file", str(template_file)] + ) + assert result.exit_code == 0 + with open(template_file) as f: + template_text = f.read() + assert config_yaml in template_text diff --git a/tests_/config/test_config.py b/tests_/config/test_config.py index 42ab6fae9b..d77160f677 100644 --- a/tests_/config/test_config.py +++ b/tests_/config/test_config.py @@ -82,10 +82,11 @@ def test_update(self, shm_config): assert shm_config.fqdn == "shm.example.com" def test_update_validation(self, shm_config): - with pytest.raises(ValidationError) as exc: + with pytest.raises( + ValidationError, + match="Value error, Expected valid email address.*not an email address", + ): shm_config.update(admin_email_address="not an email address") - assert "Value error, Expected valid email address" in exc - assert "not an email address" in exc @fixture @@ -118,7 +119,7 @@ def test_constructor(self, remote_desktop_config): sre_config = ConfigSectionSRE( databases=[DatabaseSystem.POSTGRESQL], data_provider_ip_addresses=["0.0.0.0"], # noqa: S104 - index=0, + index=1, remote_desktop=remote_desktop_config, workspace_skus=["Standard_D2s_v4"], research_user_ip_addresses=["0.0.0.0"], # noqa: S104 @@ -127,7 +128,7 @@ def test_constructor(self, remote_desktop_config): assert sre_config.data_provider_ip_addresses[0] == "0.0.0.0/32" def test_constructor_defaults(self, remote_desktop_config): - sre_config = ConfigSectionSRE(index=0) + sre_config = ConfigSectionSRE(index=1) assert sre_config.databases == [] assert sre_config.data_provider_ip_addresses == [] assert sre_config.remote_desktop == remote_desktop_config @@ -136,15 +137,14 @@ def test_constructor_defaults(self, remote_desktop_config): assert sre_config.software_packages == SoftwarePackageCategory.NONE def test_all_databases_must_be_unique(self): - with pytest.raises(ValueError) as exc: + with pytest.raises(ValueError, match="all databases must be unique"): ConfigSectionSRE( - index=0, + index=1, databases=[DatabaseSystem.POSTGRESQL, DatabaseSystem.POSTGRESQL], ) - assert "all databases must be unique" in exc def test_update(self): - sre_config = ConfigSectionSRE(index=0) + sre_config = ConfigSectionSRE(index=1) assert sre_config.databases == [] assert sre_config.data_provider_ip_addresses == [] assert sre_config.workspace_skus == [] @@ -198,8 +198,8 @@ def config_no_sres(context, azure_config, pulumi_config, shm_config): @fixture def config_sres(context, azure_config, pulumi_config, shm_config): - sre_config_1 = ConfigSectionSRE(index=0) - sre_config_2 = ConfigSectionSRE(index=1) + sre_config_1 = ConfigSectionSRE(index=1) + sre_config_2 = ConfigSectionSRE(index=2) return Config( context=context, azure=azure_config, @@ -236,6 +236,23 @@ def test_constructor(self, context, azure_config, pulumi_config, shm_config): ) assert not config.sres + def test_all_sre_indices_must_be_unique( + self, context, azure_config, pulumi_config, shm_config + ): + with pytest.raises(ValueError, match="all SRE indices must be unique"): + sre_config_1 = ConfigSectionSRE(index=1) + sre_config_2 = ConfigSectionSRE(index=1) + Config( + context=context, + azure=azure_config, + pulumi=pulumi_config, + shm=shm_config, + sres={ + "sre1": sre_config_1, + "sre2": sre_config_2, + }, + ) + def test_work_directory(self, config_sres): config = config_sres assert config.work_directory == config.context.work_directory @@ -270,8 +287,8 @@ def test_sanitise_sre_name(self, value, expected): def test_sre(self, config_sres): sre1, sre2 = config_sres.sre("sre1"), config_sres.sre("sre2") - assert sre1.index == 0 - assert sre2.index == 1 + assert sre1.index == 1 + assert sre2.index == 2 assert sre1 != sre2 def test_sre_invalid(self, config_sres): diff --git a/tests_/config/test_context_settings.py b/tests_/config/test_context_settings.py index fe08df07bf..98afafa8bb 100644 --- a/tests_/config/test_context_settings.py +++ b/tests_/config/test_context_settings.py @@ -21,21 +21,24 @@ def test_constructor(self, context_dict): def test_invalid_guid(self, context_dict): context_dict["admin_group_id"] = "not a guid" - with pytest.raises(ValidationError) as exc: + with pytest.raises( + ValidationError, match="Value error, Expected GUID, for example" + ): Context(**context_dict) - assert "Value error, Expected GUID, for example" in exc def test_invalid_location(self, context_dict): context_dict["location"] = "not_a_location" - with pytest.raises(ValidationError) as exc: + with pytest.raises( + ValidationError, match="Value error, Expected valid Azure location" + ): Context(**context_dict) - assert "Value error, Expected valid Azure location" in exc def test_invalid_subscription_name(self, context_dict): context_dict["subscription_name"] = "very " * 12 + "long name" - with pytest.raises(ValidationError) as exc: + with pytest.raises( + ValidationError, match="String should have at most 64 characters" + ): Context(**context_dict) - assert "String should have at most 64 characters" in exc def test_shm_name(self, context): assert context.shm_name == "acmedeployment" @@ -105,40 +108,47 @@ def test_null_selected(self, context_yaml): settings = ContextSettings.from_yaml(context_yaml) assert settings.selected is None assert settings.context is None - with pytest.raises(DataSafeHavenConfigError) as exc: + with pytest.raises(DataSafeHavenConfigError, match="No context selected"): settings.assert_context() - assert "No context selected" in exc def test_missing_selected(self, context_yaml): context_yaml = "\n".join(context_yaml.splitlines()[1:]) - - with pytest.raises(DataSafeHavenParameterError) as exc: + msg = "\n".join( + [ + "Could not load context settings.", + "1 validation error for ContextSettings", + "selected", + " Field required", + ] + ) + with pytest.raises(DataSafeHavenParameterError, match=msg): ContextSettings.from_yaml(context_yaml) - assert "Could not load context settings" in exc - assert "1 validation error for ContextSettings" in exc - assert "selected" in exc - assert "Field required" in exc def test_invalid_selected_input(self, context_yaml): context_yaml = context_yaml.replace( "selected: acme_deployment", "selected: invalid" ) - with pytest.raises(DataSafeHavenParameterError) as exc: + with pytest.raises( + DataSafeHavenParameterError, + match="Selected context 'invalid' is not defined.", + ): ContextSettings.from_yaml(context_yaml) - assert "Selected context 'invalid' is not defined." in exc def test_invalid_yaml(self): invalid_yaml = "a: [1,2" - with pytest.raises(DataSafeHavenConfigError) as exc: + with pytest.raises( + DataSafeHavenConfigError, match="Could not parse context settings as YAML." + ): ContextSettings.from_yaml(invalid_yaml) - assert "Could not parse context settings as YAML." in exc def test_yaml_not_dict(self): not_dict = "[1, 2, 3]" - with pytest.raises(DataSafeHavenConfigError) as exc: + with pytest.raises( + DataSafeHavenConfigError, + match="Unable to parse context settings as a dict.", + ): ContextSettings.from_yaml(not_dict) - assert "Unable to parse context settings as a dict." in exc def test_selected(self, context_settings): assert context_settings.selected == "acme_deployment" @@ -149,9 +159,10 @@ def test_set_selected(self, context_settings): assert context_settings.selected == "gems" def test_invalid_selected(self, context_settings): - with pytest.raises(DataSafeHavenParameterError) as exc: + with pytest.raises( + DataSafeHavenParameterError, match="Context 'invalid' is not defined." + ): context_settings.selected = "invalid" - assert "Context invalid is not defined." in exc def test_context(self, context_yaml, context_settings): yaml_dict = yaml.safe_load(context_yaml) @@ -183,9 +194,8 @@ def test_assert_context(self, context_settings): def test_assert_context_none(self, context_settings): context_settings.selected = None - with pytest.raises(DataSafeHavenConfigError) as exc: + with pytest.raises(DataSafeHavenConfigError, match="No context selected"): context_settings.assert_context() - assert "No context selected" in exc def test_available(self, context_settings): available = context_settings.available @@ -206,9 +216,8 @@ def test_set_update(self, context_settings): def test_update_none(self, context_settings): context_settings.selected = None - with pytest.raises(DataSafeHavenConfigError) as exc: + with pytest.raises(DataSafeHavenConfigError, match="No context selected"): context_settings.update(name="replaced") - assert "No context selected" in exc def test_add(self, context_settings): context_settings.add( @@ -224,7 +233,10 @@ def test_add(self, context_settings): assert context_settings.context.subscription_name == "Data Safe Haven (Example)" def test_invalid_add(self, context_settings): - with pytest.raises(DataSafeHavenParameterError) as exc: + with pytest.raises( + DataSafeHavenParameterError, + match="A context with key 'acme_deployment' is already defined.", + ): context_settings.add( key="acme_deployment", name="Acme Deployment", @@ -232,7 +244,6 @@ def test_invalid_add(self, context_settings): admin_group_id="d5c5c439-1115-4cb6-ab50-b8e547b6c8dd", location="uksouth", ) - assert "A context with key 'acme' is already defined." in exc def test_remove(self, context_settings): context_settings.remove("gems") @@ -240,9 +251,10 @@ def test_remove(self, context_settings): assert context_settings.selected == "acme_deployment" def test_invalid_remove(self, context_settings): - with pytest.raises(DataSafeHavenParameterError) as exc: + with pytest.raises( + DataSafeHavenParameterError, match="No context with key 'invalid'." + ): context_settings.remove("invalid") - assert "No context with key 'invalid'." in exc def test_remove_selected(self, context_settings): context_settings.remove("acme_deployment") @@ -258,9 +270,8 @@ def test_from_file(self, tmp_path, context_yaml): def test_file_not_found(self, tmp_path): config_file_path = tmp_path / "config.yaml" - with pytest.raises(DataSafeHavenConfigError) as exc: + with pytest.raises(DataSafeHavenConfigError, match="Could not find file"): ContextSettings.from_file(config_file_path=config_file_path) - assert "Could not find file" in exc def test_write(self, tmp_path, context_yaml): config_file_path = tmp_path / "config.yaml" diff --git a/tests_/conftest.py b/tests_/conftest.py index f97a7c2fcb..7188a46f15 100644 --- a/tests_/conftest.py +++ b/tests_/conftest.py @@ -37,7 +37,7 @@ def config_yaml(): sre1: data_provider_ip_addresses: [] databases: [] - index: 0 + index: 1 remote_desktop: allow_copy: false allow_paste: false @@ -47,7 +47,7 @@ def config_yaml(): sre2: data_provider_ip_addresses: [] databases: [] - index: 1 + index: 2 remote_desktop: allow_copy: false allow_paste: false diff --git a/tests_/external/api/azure_api.py b/tests_/external/api/azure_api.py index 1eab649234..2f4215841b 100644 --- a/tests_/external/api/azure_api.py +++ b/tests_/external/api/azure_api.py @@ -32,6 +32,7 @@ def test_get_keyvault_key(self, mock_key_client): # noqa: ARG002 def test_get_keyvault_key_missing(self, mock_key_client): # noqa: ARG002 api = AzureApi("subscription name") - with pytest.raises(DataSafeHavenAzureError) as exc: + with pytest.raises( + DataSafeHavenAzureError, match="Failed to retrieve key does not exist" + ): api.get_keyvault_key("does not exist", "key vault name") - assert "Failed to retrieve key does not exist" in exc diff --git a/tests_/functions/test_typer_validators.py b/tests_/functions/test_typer_validators.py index 4cba74fc29..64364e3f7d 100644 --- a/tests_/functions/test_typer_validators.py +++ b/tests_/functions/test_typer_validators.py @@ -23,9 +23,8 @@ def test_typer_validate_aad_guid(self, guid): ], ) def test_typer_validate_aad_guid_fail(self, guid): - with pytest.raises(BadParameter) as exc: + with pytest.raises(BadParameter, match="Expected GUID"): typer_validate_aad_guid(guid) - assert "Expected GUID" in exc def test_typer_validate_aad_guid_nonae(self): assert typer_validate_aad_guid(None) is None diff --git a/tests_/functions/test_validators.py b/tests_/functions/test_validators.py index 2588f2d24a..b7408fc826 100644 --- a/tests_/functions/test_validators.py +++ b/tests_/functions/test_validators.py @@ -22,9 +22,8 @@ def test_validate_aad_guid(self, guid): ], ) def test_validate_aad_guid_fail(self, guid): - with pytest.raises(ValueError) as exc: + with pytest.raises(ValueError, match="Expected GUID"): validate_aad_guid(guid) - assert "Expected GUID" in exc class TestValidateFqdn: @@ -50,6 +49,7 @@ def test_validate_fqdn(self, fqdn): ], ) def test_validate_fqdn_fail(self, fqdn): - with pytest.raises(ValueError) as exc: + with pytest.raises( + ValueError, match="Expected valid fully qualified domain name" + ): validate_fqdn(fqdn) - assert "Expected valid fully qualified domain name" in exc