Skip to content

Commit

Permalink
Merge pull request #1716 from jemrobinson/1709-fix-context-issues
Browse files Browse the repository at this point in the history
Fix some issues with context handling at deployment time
  • Loading branch information
jemrobinson authored Jan 29, 2024
2 parents 81847af + 65828b5 commit 66cd779
Show file tree
Hide file tree
Showing 15 changed files with 194 additions and 101 deletions.
6 changes: 4 additions & 2 deletions .github/workflows/lint_code.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,10 @@ jobs:
python-version: 3.11
- name: Install hatch
run: pip install hatch
- name: Print Ruff version
run: hatch run lint:ruff --version
- name: Print package versions
run: |
hatch run lint:ruff --version
hatch run lint:black --version
- name: Lint Python
run: hatch run lint:all

Expand Down
12 changes: 9 additions & 3 deletions data_safe_haven/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ Install the following requirements before starting
> dsh deploy shm
```

You will be prompted for various settings.
Run `dsh deploy shm -h` to see the necessary command line flags and provide them as arguments.

- Add one or more users from a CSV file with columns named (`GivenName`, `Surname`, `Phone`, `Email`, `CountryCode`).
Expand All @@ -40,13 +39,20 @@ Run `dsh deploy shm -h` to see the necessary command line flags and provide them
> dsh admin add-users <my CSV users file>
```

- Next deploy the infrastructure for one or more Secure Research Environments (SREs) [approx 30 minutes]:
- Create the configuration for one or more Secure Research Environments (SREs).

```console
> dsh config show --file config.yaml
> vim config.yaml
> dsh config upload config.yaml
```

- Next deploy the infrastructure [approx 30 minutes]:

```console
> dsh deploy sre <SRE name>
```

You will be prompted for various settings.
Run `dsh deploy sre -h` to see the necessary command line flags and provide them as arguments.

- Next add one or more existing users to your SRE
Expand Down
2 changes: 1 addition & 1 deletion data_safe_haven/administration/users/user_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def get_usernames(self) -> dict[str, list[str]]:
usernames = {}
usernames["Azure AD"] = self.get_usernames_azure_ad()
usernames["Domain controller"] = self.get_usernames_domain_controller()
for sre_name in self.config.sres.keys():
for sre_name in self.config.sre_names:
usernames[f"SRE {sre_name}"] = self.get_usernames_guacamole(sre_name)
return usernames

Expand Down
13 changes: 11 additions & 2 deletions data_safe_haven/commands/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,17 @@ def upload(


@config_command_group.command()
def show() -> None:
def show(
file: Annotated[
Optional[Path], # noqa: UP007
typer.Option(help="File path to write configuration template to."),
] = None
) -> None:
"""Print the configuration for the selected Data Safe Haven context"""
context = ContextSettings.from_file().assert_context()
config = Config.from_remote(context)
print(config.to_yaml())
if file:
with open(file, "w") as outfile:
outfile.write(config.to_yaml())
else:
print(config.to_yaml())
8 changes: 7 additions & 1 deletion data_safe_haven/commands/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from data_safe_haven.functions import bcrypt_salt, password
from data_safe_haven.infrastructure import SHMStackManager, SREStackManager
from data_safe_haven.provisioning import SHMProvisioningManager, SREProvisioningManager
from data_safe_haven.utility import LoggingSingleton

deploy_command_group = typer.Typer()

Expand Down Expand Up @@ -99,12 +100,17 @@ def sre(
] = None,
) -> None:
"""Deploy a Secure Research Environment"""
logger = LoggingSingleton()
context = ContextSettings.from_file().assert_context()
config = Config.from_remote(context)

sre_name = config.sanitise_sre_name(name)

try:
# Exit if SRE name is not recognised
if sre_name not in config.sre_names:
logger.fatal(f"Could not find configuration details for SRE '{sre_name}'.")
raise typer.Exit(1)

# Load GraphAPI as this may require user-interaction that is not possible as
# part of a Pulumi declarative command
graph_api = GraphApi(
Expand Down
63 changes: 45 additions & 18 deletions data_safe_haven/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,12 @@
from pydantic import (
BaseModel,
Field,
FieldSerializationInfo,
ValidationError,
field_serializer,
field_validator,
)
from yaml import YAMLError

from data_safe_haven import __version__
from data_safe_haven.config.context_settings import Context
from data_safe_haven.exceptions import (
DataSafeHavenConfigError,
DataSafeHavenParameterError,
Expand All @@ -44,6 +41,8 @@
TimeZone,
)

from .context_settings import Context


class ConfigSectionAzure(BaseModel, validate_assignment=True):
admin_group_id: Guid = Field(..., exclude=True)
Expand Down Expand Up @@ -155,7 +154,7 @@ class ConfigSectionSRE(BaseModel, validate_assignment=True):
data_provider_ip_addresses: list[IpAddress] = Field(
..., default_factory=list[IpAddress]
)
index: int = Field(..., ge=0)
index: int = Field(..., ge=1, le=256)
remote_desktop: ConfigSubsectionRemoteDesktopOpts = Field(
..., default_factory=ConfigSubsectionRemoteDesktopOpts
)
Expand All @@ -168,20 +167,13 @@ class ConfigSectionSRE(BaseModel, validate_assignment=True):
@field_validator("databases")
@classmethod
def all_databases_must_be_unique(
cls, v: list[DatabaseSystem]
cls, v: list[DatabaseSystem | str]
) -> list[DatabaseSystem]:
if len(v) != len(set(v)):
v_ = [DatabaseSystem(d) for d in v]
if len(v_) != len(set(v_)):
msg = "all databases must be unique"
raise ValueError(msg)
return v

@field_serializer("software_packages")
def software_packages_serializer(
self,
packages: SoftwarePackageCategory,
info: FieldSerializationInfo, # noqa: ARG002
) -> str:
return packages.value
return v_

def update(
self,
Expand Down Expand Up @@ -260,6 +252,17 @@ def __init__(self, context: Context, **kwargs: dict[Any, Any]):
tags = ConfigSectionTags(context)
super().__init__(context=context, tags=tags, **kwargs)

@field_validator("sres")
@classmethod
def all_sre_indices_must_be_unique(
cls, v: dict[str, ConfigSectionSRE]
) -> dict[str, ConfigSectionSRE]:
indices = [s.index for s in v.values()]
if len(indices) != len(set(indices)):
msg = "all SRE indices must be unique"
raise ValueError(msg)
return v

@property
def work_directory(self) -> Path:
return self.context.work_directory
Expand All @@ -276,9 +279,15 @@ def pulumi_encryption_key(self) -> KeyVaultKey:

@property
def pulumi_encryption_key_version(self) -> str:
"""ID for the Pulumi encryption key"""
key_id: str = self.pulumi_encryption_key.id
return key_id.split("/")[-1]

@property
def sre_names(self) -> list[str]:
"""Names of all SREs"""
return list(self.sres.keys())

def is_complete(self, *, require_sres: bool) -> bool:
if require_sres:
if not self.sres:
Expand All @@ -293,11 +302,16 @@ def sanitise_sre_name(name: str) -> str:

def sre(self, name: str) -> ConfigSectionSRE:
"""Return the config entry for this SRE, raising an exception if it does not exist"""
if name not in self.sres.keys():
if name not in self.sre_names:
msg = f"SRE {name} does not exist"
raise DataSafeHavenConfigError(msg)
return self.sres[name]

def remove_sre(self, name: str) -> None:
"""Remove SRE config section by name"""
if name in self.sre_names:
del self.sres[name]

def add_stack(self, name: str, path: Path) -> None:
"""Add a Pulumi stack file to config"""
if self.pulumi:
Expand Down Expand Up @@ -336,6 +350,19 @@ def template(cls, context: Context) -> Config:
fqdn="TRE domain name",
timezone="Timezone",
),
sres={
"example": ConfigSectionSRE.model_construct(
databases=["List of database systems to enable"],
data_provider_ip_addresses=["Data provider IP addresses"],
remote_desktop=ConfigSubsectionRemoteDesktopOpts.model_construct(
allow_copy="Whether to allow copying text out of the environment",
allow_paste="Whether to allow pasting text into the environment",
),
workspace_skus=["Azure VM SKUs"],
research_user_ip_addresses=["Research user IP addresses"],
software_packages=SoftwarePackageCategory.ANY,
)
},
)

@classmethod
Expand All @@ -356,7 +383,7 @@ def from_yaml(cls, context: Context, config_yaml: str) -> Config:
config_dict[section]["context"] = context

try:
return Config.model_validate(config_dict)
return Config.model_validate(config_dict, strict=True)
except ValidationError as exc:
msg = f"Could not load configuration.\n{exc}"
raise DataSafeHavenParameterError(msg) from exc
Expand All @@ -373,7 +400,7 @@ def from_remote(cls, context: Context) -> Config:
return Config.from_yaml(context, config_yaml)

def to_yaml(self) -> str:
return yaml.dump(self.model_dump(), indent=2)
return yaml.dump(self.model_dump(mode="json"), indent=2)

def upload(self) -> None:
"""Upload config to Azure storage"""
Expand Down
9 changes: 5 additions & 4 deletions data_safe_haven/external/api/azure_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,11 @@ def confirm(self) -> None:
return None

account = self.account
self.logger.info(
f"name: {account.name} (id: {account.id_}\ntenant: {account.tenant_id})"
)
if not typer.confirm("Is this the Azure account you expect?\n"):
self.logger.info(f"Azure user: {account.name} ({account.id_})")
self.logger.info(f"Azure tenant ID: {account.tenant_id})")
if not self.logger.confirm(
"Is this the Azure account you expect?", default_to_yes=False
):
self.logger.error(
"Please use `az login` to connect to the correct Azure CLI account"
)
Expand Down
24 changes: 12 additions & 12 deletions data_safe_haven/infrastructure/stacks/declarative_sre.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def run(self) -> None:
shm_networking_resource_group_name=self.pulumi_opts.require(
"shm-networking-resource_group_name"
),
sre_index=self.cfg.sres[self.sre_name].index,
sre_index=self.cfg.sre(self.sre_name).index,
),
tags=self.cfg.tags.model_dump(),
)
Expand Down Expand Up @@ -124,11 +124,11 @@ def run(self) -> None:
"shm-networking-virtual_network_name"
),
shm_zone_name=self.cfg.shm.fqdn,
sre_index=self.cfg.sres[self.sre_name].index,
sre_index=self.cfg.sre(self.sre_name).index,
sre_name=self.sre_name,
user_public_ip_ranges=self.cfg.sres[
user_public_ip_ranges=self.cfg.sre(
self.sre_name
].research_user_ip_addresses,
).research_user_ip_addresses,
),
tags=self.cfg.tags.model_dump(),
)
Expand All @@ -148,7 +148,7 @@ def run(self) -> None:
resource_group_name=self.pulumi_opts.require(
"shm-monitoring-resource_group_name"
),
sre_index=self.cfg.sres[self.sre_name].index,
sre_index=self.cfg.sre(self.sre_name).index,
timezone=self.cfg.shm.timezone,
),
tags=self.cfg.tags.model_dump(),
Expand All @@ -162,9 +162,9 @@ def run(self) -> None:
admin_email_address=self.cfg.shm.admin_email_address,
admin_group_id=self.cfg.azure.admin_group_id,
admin_ip_addresses=self.cfg.shm.admin_ip_addresses,
data_provider_ip_addresses=self.cfg.sres[
data_provider_ip_addresses=self.cfg.sre(
self.sre_name
].data_provider_ip_addresses,
).data_provider_ip_addresses,
dns_record=networking.shm_ns_record,
dns_server_admin_password=dns.password_admin,
location=self.cfg.azure.location,
Expand Down Expand Up @@ -204,8 +204,8 @@ def run(self) -> None:
aad_application_fqdn=networking.sre_fqdn,
aad_auth_token=self.graph_api_token,
aad_tenant_id=self.cfg.shm.aad_tenant_id,
allow_copy=self.cfg.sres[self.sre_name].remote_desktop.allow_copy,
allow_paste=self.cfg.sres[self.sre_name].remote_desktop.allow_paste,
allow_copy=self.cfg.sre(self.sre_name).remote_desktop.allow_copy,
allow_paste=self.cfg.sre(self.sre_name).remote_desktop.allow_paste,
database_password=data.password_user_database_admin,
dns_server_ip=dns.ip_address,
ldap_bind_dn=ldap_bind_dn,
Expand Down Expand Up @@ -260,7 +260,7 @@ def run(self) -> None:
subscription_name=self.cfg.context.subscription_name,
virtual_network_resource_group=networking.resource_group,
virtual_network=networking.virtual_network,
vm_details=list(enumerate(self.cfg.sres[self.sre_name].workspace_skus)),
vm_details=list(enumerate(self.cfg.sre(self.sre_name).workspace_skus)),
),
tags=self.cfg.tags.model_dump(),
)
Expand All @@ -271,7 +271,7 @@ def run(self) -> None:
self.stack_name,
SREUserServicesProps(
database_service_admin_password=data.password_database_service_admin,
databases=self.cfg.sres[self.sre_name].databases,
databases=self.cfg.sre(self.sre_name).databases,
dns_resource_group_name=dns.resource_group.name,
dns_server_ip=dns.ip_address,
domain_netbios_name=self.pulumi_opts.require(
Expand All @@ -288,7 +288,7 @@ def run(self) -> None:
location=self.cfg.azure.location,
networking_resource_group_name=networking.resource_group.name,
nexus_admin_password=data.password_nexus_admin,
software_packages=self.cfg.sres[self.sre_name].software_packages,
software_packages=self.cfg.sre(self.sre_name).software_packages,
sre_fqdn=networking.sre_fqdn,
sre_private_dns_zone_id=networking.sre_private_dns_zone_id,
storage_account_key=data.storage_account_data_configuration_key,
Expand Down
20 changes: 17 additions & 3 deletions tests_/commands/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,20 @@ def test_template(self, runner):
result = runner.invoke(config_command_group, ["template"])
assert result.exit_code == 0
assert "subscription_id: Azure subscription ID" in result.stdout
assert "sres: {}" in result.stdout
assert "shm:" in result.stdout
assert "sres:" in result.stdout

def test_template_file(self, runner, tmp_path):
template_file = (tmp_path / "template.yaml").absolute()
template_file = (tmp_path / "template_create.yaml").absolute()
result = runner.invoke(
config_command_group, ["template", "--file", str(template_file)]
)
assert result.exit_code == 0
with open(template_file) as f:
template_text = f.read()
assert "subscription_id: Azure subscription ID" in template_text
assert "sres: {}" in template_text
assert "shm:" in template_text
assert "sres:" in template_text


class TestUpload:
Expand All @@ -41,3 +43,15 @@ def test_show(self, runner, config_yaml, mock_download_blob): # noqa: ARG002
result = runner.invoke(config_command_group, ["show"])
assert result.exit_code == 0
assert config_yaml in result.stdout

def test_show_file(
self, runner, config_yaml, mock_download_blob, tmp_path # noqa: ARG002
):
template_file = (tmp_path / "template_show.yaml").absolute()
result = runner.invoke(
config_command_group, ["show", "--file", str(template_file)]
)
assert result.exit_code == 0
with open(template_file) as f:
template_text = f.read()
assert config_yaml in template_text
Loading

0 comments on commit 66cd779

Please sign in to comment.