Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix some issues with context handling at deployment time #1716

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions .github/workflows/lint_code.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,10 @@ jobs:
python-version: 3.11
- name: Install hatch
run: pip install hatch
- name: Print Ruff version
run: hatch run lint:ruff --version
- name: Print package versions
run: |
hatch run lint:ruff --version
hatch run lint:black --version
- name: Lint Python
run: hatch run lint:all

Expand Down
12 changes: 9 additions & 3 deletions data_safe_haven/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ Install the following requirements before starting
> dsh deploy shm
```

You will be prompted for various settings.
Run `dsh deploy shm -h` to see the necessary command line flags and provide them as arguments.

- Add one or more users from a CSV file with columns named (`GivenName`, `Surname`, `Phone`, `Email`, `CountryCode`).
Expand All @@ -40,13 +39,20 @@ Run `dsh deploy shm -h` to see the necessary command line flags and provide them
> dsh admin add-users <my CSV users file>
```

- Next deploy the infrastructure for one or more Secure Research Environments (SREs) [approx 30 minutes]:
- Create the configuration for one or more Secure Research Environments (SREs).

```console
> dsh config show --file config.yaml
> vim config.yaml
> dsh config upload config.yaml
```

- Next deploy the infrastructure [approx 30 minutes]:

```console
> dsh deploy sre <SRE name>
```

You will be prompted for various settings.
Run `dsh deploy sre -h` to see the necessary command line flags and provide them as arguments.

- Next add one or more existing users to your SRE
Expand Down
2 changes: 1 addition & 1 deletion data_safe_haven/administration/users/user_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def get_usernames(self) -> dict[str, list[str]]:
usernames = {}
usernames["Azure AD"] = self.get_usernames_azure_ad()
usernames["Domain controller"] = self.get_usernames_domain_controller()
for sre_name in self.config.sres.keys():
for sre_name in self.config.sre_names:
usernames[f"SRE {sre_name}"] = self.get_usernames_guacamole(sre_name)
return usernames

Expand Down
13 changes: 11 additions & 2 deletions data_safe_haven/commands/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,17 @@ def upload(


@config_command_group.command()
def show() -> None:
def show(
file: Annotated[
Optional[Path], # noqa: UP007
typer.Option(help="File path to write configuration template to."),
] = None
) -> None:
"""Print the configuration for the selected Data Safe Haven context"""
context = ContextSettings.from_file().assert_context()
config = Config.from_remote(context)
print(config.to_yaml())
if file:
with open(file, "w") as outfile:
outfile.write(config.to_yaml())
else:
print(config.to_yaml())
jemrobinson marked this conversation as resolved.
Show resolved Hide resolved
8 changes: 7 additions & 1 deletion data_safe_haven/commands/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from data_safe_haven.functions import bcrypt_salt, password
from data_safe_haven.infrastructure import SHMStackManager, SREStackManager
from data_safe_haven.provisioning import SHMProvisioningManager, SREProvisioningManager
from data_safe_haven.utility import LoggingSingleton

deploy_command_group = typer.Typer()

Expand Down Expand Up @@ -99,12 +100,17 @@ def sre(
] = None,
) -> None:
"""Deploy a Secure Research Environment"""
logger = LoggingSingleton()
context = ContextSettings.from_file().assert_context()
config = Config.from_remote(context)

sre_name = config.sanitise_sre_name(name)

try:
# Exit if SRE name is not recognised
if sre_name not in config.sre_names:
logger.fatal(f"Could not find configuration details for SRE '{sre_name}'.")
raise typer.Exit(1)

# Load GraphAPI as this may require user-interaction that is not possible as
# part of a Pulumi declarative command
graph_api = GraphApi(
Expand Down
63 changes: 45 additions & 18 deletions data_safe_haven/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,12 @@
from pydantic import (
BaseModel,
Field,
FieldSerializationInfo,
ValidationError,
field_serializer,
field_validator,
)
from yaml import YAMLError

from data_safe_haven import __version__
from data_safe_haven.config.context_settings import Context
from data_safe_haven.exceptions import (
DataSafeHavenConfigError,
DataSafeHavenParameterError,
Expand All @@ -44,6 +41,8 @@
TimeZone,
)

from .context_settings import Context


class ConfigSectionAzure(BaseModel, validate_assignment=True):
admin_group_id: Guid = Field(..., exclude=True)
Expand Down Expand Up @@ -155,7 +154,7 @@ class ConfigSectionSRE(BaseModel, validate_assignment=True):
data_provider_ip_addresses: list[IpAddress] = Field(
..., default_factory=list[IpAddress]
)
index: int = Field(..., ge=0)
index: int = Field(..., ge=1, le=256)
remote_desktop: ConfigSubsectionRemoteDesktopOpts = Field(
..., default_factory=ConfigSubsectionRemoteDesktopOpts
)
Expand All @@ -168,20 +167,13 @@ class ConfigSectionSRE(BaseModel, validate_assignment=True):
@field_validator("databases")
@classmethod
def all_databases_must_be_unique(
cls, v: list[DatabaseSystem]
cls, v: list[DatabaseSystem | str]
) -> list[DatabaseSystem]:
if len(v) != len(set(v)):
v_ = [DatabaseSystem(d) for d in v]
if len(v_) != len(set(v_)):
msg = "all databases must be unique"
raise ValueError(msg)
return v

@field_serializer("software_packages")
def software_packages_serializer(
self,
packages: SoftwarePackageCategory,
info: FieldSerializationInfo, # noqa: ARG002
) -> str:
return packages.value
JimMadge marked this conversation as resolved.
Show resolved Hide resolved
return v_

def update(
self,
Expand Down Expand Up @@ -260,6 +252,17 @@ def __init__(self, context: Context, **kwargs: dict[Any, Any]):
tags = ConfigSectionTags(context)
super().__init__(context=context, tags=tags, **kwargs)

@field_validator("sres")
@classmethod
def all_sre_indices_must_be_unique(
cls, v: dict[str, ConfigSectionSRE]
) -> dict[str, ConfigSectionSRE]:
indices = [s.index for s in v.values()]
if len(indices) != len(set(indices)):
msg = "all SRE indices must be unique"
raise ValueError(msg)
return v
jemrobinson marked this conversation as resolved.
Show resolved Hide resolved
jemrobinson marked this conversation as resolved.
Show resolved Hide resolved

@property
def work_directory(self) -> Path:
return self.context.work_directory
Expand All @@ -276,9 +279,15 @@ def pulumi_encryption_key(self) -> KeyVaultKey:

@property
def pulumi_encryption_key_version(self) -> str:
"""ID for the Pulumi encryption key"""
key_id: str = self.pulumi_encryption_key.id
return key_id.split("/")[-1]

@property
def sre_names(self) -> list[str]:
"""Names of all SREs"""
return list(self.sres.keys())

jemrobinson marked this conversation as resolved.
Show resolved Hide resolved
def is_complete(self, *, require_sres: bool) -> bool:
if require_sres:
if not self.sres:
Expand All @@ -293,11 +302,16 @@ def sanitise_sre_name(name: str) -> str:

def sre(self, name: str) -> ConfigSectionSRE:
"""Return the config entry for this SRE, raising an exception if it does not exist"""
if name not in self.sres.keys():
if name not in self.sre_names:
msg = f"SRE {name} does not exist"
raise DataSafeHavenConfigError(msg)
return self.sres[name]

def remove_sre(self, name: str) -> None:
"""Remove SRE config section by name"""
if name in self.sre_names:
del self.sres[name]

def add_stack(self, name: str, path: Path) -> None:
"""Add a Pulumi stack file to config"""
if self.pulumi:
Expand Down Expand Up @@ -336,6 +350,19 @@ def template(cls, context: Context) -> Config:
fqdn="TRE domain name",
timezone="Timezone",
),
sres={
"example": ConfigSectionSRE.model_construct(
databases=["List of database systems to enable"],
data_provider_ip_addresses=["Data provider IP addresses"],
remote_desktop=ConfigSubsectionRemoteDesktopOpts.model_construct(
allow_copy="Whether to allow copying text out of the environment",
allow_paste="Whether to allow pasting text into the environment",
),
workspace_skus=["Azure VM SKUs"],
research_user_ip_addresses=["Research user IP addresses"],
software_packages=SoftwarePackageCategory.ANY,
)
},
)

@classmethod
Expand All @@ -356,7 +383,7 @@ def from_yaml(cls, context: Context, config_yaml: str) -> Config:
config_dict[section]["context"] = context

try:
return Config.model_validate(config_dict)
return Config.model_validate(config_dict, strict=True)
except ValidationError as exc:
msg = f"Could not load configuration.\n{exc}"
raise DataSafeHavenParameterError(msg) from exc
Expand All @@ -373,7 +400,7 @@ def from_remote(cls, context: Context) -> Config:
return Config.from_yaml(context, config_yaml)

def to_yaml(self) -> str:
return yaml.dump(self.model_dump(), indent=2)
return yaml.dump(self.model_dump(mode="json"), indent=2)

def upload(self) -> None:
"""Upload config to Azure storage"""
Expand Down
9 changes: 5 additions & 4 deletions data_safe_haven/external/api/azure_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,11 @@ def confirm(self) -> None:
return None

account = self.account
self.logger.info(
f"name: {account.name} (id: {account.id_}\ntenant: {account.tenant_id})"
)
if not typer.confirm("Is this the Azure account you expect?\n"):
self.logger.info(f"Azure user: {account.name} ({account.id_})")
self.logger.info(f"Azure tenant ID: {account.tenant_id})")
if not self.logger.confirm(
"Is this the Azure account you expect?", default_to_yes=False
):
JimMadge marked this conversation as resolved.
Show resolved Hide resolved
self.logger.error(
"Please use `az login` to connect to the correct Azure CLI account"
)
Expand Down
24 changes: 12 additions & 12 deletions data_safe_haven/infrastructure/stacks/declarative_sre.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def run(self) -> None:
shm_networking_resource_group_name=self.pulumi_opts.require(
"shm-networking-resource_group_name"
),
sre_index=self.cfg.sres[self.sre_name].index,
sre_index=self.cfg.sre(self.sre_name).index,
),
tags=self.cfg.tags.model_dump(),
)
Expand Down Expand Up @@ -124,11 +124,11 @@ def run(self) -> None:
"shm-networking-virtual_network_name"
),
shm_zone_name=self.cfg.shm.fqdn,
sre_index=self.cfg.sres[self.sre_name].index,
sre_index=self.cfg.sre(self.sre_name).index,
sre_name=self.sre_name,
user_public_ip_ranges=self.cfg.sres[
user_public_ip_ranges=self.cfg.sre(
self.sre_name
].research_user_ip_addresses,
).research_user_ip_addresses,
),
tags=self.cfg.tags.model_dump(),
)
Expand All @@ -148,7 +148,7 @@ def run(self) -> None:
resource_group_name=self.pulumi_opts.require(
"shm-monitoring-resource_group_name"
),
sre_index=self.cfg.sres[self.sre_name].index,
sre_index=self.cfg.sre(self.sre_name).index,
timezone=self.cfg.shm.timezone,
),
tags=self.cfg.tags.model_dump(),
Expand All @@ -162,9 +162,9 @@ def run(self) -> None:
admin_email_address=self.cfg.shm.admin_email_address,
admin_group_id=self.cfg.azure.admin_group_id,
admin_ip_addresses=self.cfg.shm.admin_ip_addresses,
data_provider_ip_addresses=self.cfg.sres[
data_provider_ip_addresses=self.cfg.sre(
self.sre_name
].data_provider_ip_addresses,
).data_provider_ip_addresses,
dns_record=networking.shm_ns_record,
dns_server_admin_password=dns.password_admin,
location=self.cfg.azure.location,
Expand Down Expand Up @@ -204,8 +204,8 @@ def run(self) -> None:
aad_application_fqdn=networking.sre_fqdn,
aad_auth_token=self.graph_api_token,
aad_tenant_id=self.cfg.shm.aad_tenant_id,
allow_copy=self.cfg.sres[self.sre_name].remote_desktop.allow_copy,
allow_paste=self.cfg.sres[self.sre_name].remote_desktop.allow_paste,
allow_copy=self.cfg.sre(self.sre_name).remote_desktop.allow_copy,
allow_paste=self.cfg.sre(self.sre_name).remote_desktop.allow_paste,
database_password=data.password_user_database_admin,
dns_server_ip=dns.ip_address,
ldap_bind_dn=ldap_bind_dn,
Expand Down Expand Up @@ -260,7 +260,7 @@ def run(self) -> None:
subscription_name=self.cfg.context.subscription_name,
virtual_network_resource_group=networking.resource_group,
virtual_network=networking.virtual_network,
vm_details=list(enumerate(self.cfg.sres[self.sre_name].workspace_skus)),
vm_details=list(enumerate(self.cfg.sre(self.sre_name).workspace_skus)),
),
tags=self.cfg.tags.model_dump(),
)
Expand All @@ -271,7 +271,7 @@ def run(self) -> None:
self.stack_name,
SREUserServicesProps(
database_service_admin_password=data.password_database_service_admin,
databases=self.cfg.sres[self.sre_name].databases,
databases=self.cfg.sre(self.sre_name).databases,
dns_resource_group_name=dns.resource_group.name,
dns_server_ip=dns.ip_address,
domain_netbios_name=self.pulumi_opts.require(
Expand All @@ -288,7 +288,7 @@ def run(self) -> None:
location=self.cfg.azure.location,
networking_resource_group_name=networking.resource_group.name,
nexus_admin_password=data.password_nexus_admin,
software_packages=self.cfg.sres[self.sre_name].software_packages,
software_packages=self.cfg.sre(self.sre_name).software_packages,
sre_fqdn=networking.sre_fqdn,
sre_private_dns_zone_id=networking.sre_private_dns_zone_id,
storage_account_key=data.storage_account_data_configuration_key,
Expand Down
20 changes: 17 additions & 3 deletions tests_/commands/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,20 @@ def test_template(self, runner):
result = runner.invoke(config_command_group, ["template"])
assert result.exit_code == 0
assert "subscription_id: Azure subscription ID" in result.stdout
assert "sres: {}" in result.stdout
assert "shm:" in result.stdout
assert "sres:" in result.stdout

def test_template_file(self, runner, tmp_path):
template_file = (tmp_path / "template.yaml").absolute()
template_file = (tmp_path / "template_create.yaml").absolute()
result = runner.invoke(
config_command_group, ["template", "--file", str(template_file)]
)
assert result.exit_code == 0
with open(template_file) as f:
template_text = f.read()
assert "subscription_id: Azure subscription ID" in template_text
assert "sres: {}" in template_text
assert "shm:" in template_text
assert "sres:" in template_text


class TestUpload:
Expand All @@ -41,3 +43,15 @@ def test_show(self, runner, config_yaml, mock_download_blob): # noqa: ARG002
result = runner.invoke(config_command_group, ["show"])
assert result.exit_code == 0
assert config_yaml in result.stdout

def test_show_file(
self, runner, config_yaml, mock_download_blob, tmp_path # noqa: ARG002
):
template_file = (tmp_path / "template_show.yaml").absolute()
result = runner.invoke(
config_command_group, ["show", "--file", str(template_file)]
)
assert result.exit_code == 0
with open(template_file) as f:
template_text = f.read()
assert config_yaml in template_text
Loading
Loading