Skip to content

Commit

Permalink
Add basic tests for Config
Browse files Browse the repository at this point in the history
  • Loading branch information
JimMadge committed Nov 22, 2023
1 parent e7ca180 commit e8e2e5b
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 9 deletions.
16 changes: 13 additions & 3 deletions data_safe_haven/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import ClassVar

import yaml
from pydantic import BaseModel, Field, ValidationError, computed_field, field_validator
from pydantic import BaseModel, Field, ValidationError, field_validator
from yaml import YAMLError

from data_safe_haven import __version__
Expand Down Expand Up @@ -251,12 +251,22 @@ class Config(BaseModel, validate_assignment=True):
pulumi: ConfigSectionPulumi | None = None
shm: ConfigSectionSHM | None = None
tags: ConfigSectionTags | None = None
sres: dict[str, ConfigSectionSRE] | None = None
sres: dict[str, ConfigSectionSRE] = Field(
default_factory=dict[str, ConfigSectionSRE]
)

@computed_field
@property
def work_directory(self) -> str:
return self.context.work_directory

def is_complete(self, *, require_sres: bool) -> bool:
if require_sres:
if not self.sres:
return False
if not all((self.azure, self.pulumi, self.shm, self.tags)):
return False
return True

def read_stack(self, name: str, path: pathlib.Path) -> None:
"""Add a Pulumi stack file to config"""
with open(path, encoding="utf-8") as f_stack:
Expand Down
92 changes: 86 additions & 6 deletions tests_/config/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pytest import fixture

from data_safe_haven.config.config import (
Config,
ConfigSectionAzure,
ConfigSectionPulumi,
ConfigSectionSHM,
Expand All @@ -14,6 +15,15 @@
from data_safe_haven.version import __version__


@fixture
def azure_config(context):
return ConfigSectionAzure.from_context(
context=context,
subscription_id="d5c5c439-1115-4cb6-ab50-b8e547b6c8dd",
tenant_id="d5c5c439-1115-4cb6-ab50-b8e547b6c8dd",
)


class TestConfigSectionAzure:
def test_constructor(self):
ConfigSectionAzure(
Expand All @@ -32,6 +42,11 @@ def test_from_context(self, context):
assert azure_config.location == context.location


@fixture
def pulumi_config():
return ConfigSectionPulumi(encryption_key_version="lorem")


class TestConfigSectionPulumi:
def test_constructor_defaults(self):
pulumi_config = ConfigSectionPulumi(encryption_key_version="lorem")
Expand All @@ -41,13 +56,13 @@ def test_constructor_defaults(self):


@fixture
def shm_config():
return ConfigSectionSHM(
def shm_config(context):
return ConfigSectionSHM.from_context(
context=context,
aad_tenant_id="d5c5c439-1115-4cb6-ab50-b8e547b6c8dd",
admin_email_address="admin@example.com",
admin_ip_addresses=["0.0.0.0"], # noqa: S104
fqdn="shm.acme.com",
name="ACME SHM",
timezone="UTC",
)

Expand Down Expand Up @@ -98,16 +113,16 @@ def test_constructor(self):
def test_constructor_defaults(self):
remote_desktop_config = ConfigSubsectionRemoteDesktopOpts()
assert not all(
[remote_desktop_config.allow_copy, remote_desktop_config.allow_paste]
(remote_desktop_config.allow_copy, remote_desktop_config.allow_paste)
)

def test_update(self, remote_desktop_config):
assert not all(
[remote_desktop_config.allow_copy, remote_desktop_config.allow_paste]
(remote_desktop_config.allow_copy, remote_desktop_config.allow_paste)
)
remote_desktop_config.update(allow_copy=True, allow_paste=True)
assert all(
[remote_desktop_config.allow_copy, remote_desktop_config.allow_paste]
(remote_desktop_config.allow_copy, remote_desktop_config.allow_paste)
)


Expand Down Expand Up @@ -162,6 +177,11 @@ def test_update(self):
assert sre_config.software_packages == SoftwarePackageCategory.ANY


@fixture
def tags_config(context):
return ConfigSectionTags.from_context(context)


class TestConfigSectionTags:
def test_constructor(self):
tags_config = ConfigSectionTags(deployment="Test Deployment")
Expand All @@ -176,3 +196,63 @@ def test_from_context(self, context):
assert tags_config.deployed_by == "Python"
assert tags_config.project == "Data Safe Haven"
assert tags_config.version == __version__


@fixture
def config_no_sres(context, azure_config, pulumi_config, shm_config, tags_config):
return Config(
context=context,
azure=azure_config,
pulumi=pulumi_config,
shm=shm_config,
tags=tags_config
)


@fixture
def config_sres(context, azure_config, pulumi_config, shm_config, tags_config):
sre_config_1 = ConfigSectionSRE(index=0)
sre_config_2 = ConfigSectionSRE(index=1)
return Config(
context=context,
azure=azure_config,
pulumi=pulumi_config,
shm=shm_config,
sres={
"sre1": sre_config_1,
"sre2": sre_config_2,
},
tags=tags_config
)


class TestConfig:
def test_constructor_defaults(self, context):
config = Config(context=context)
assert config.context == context
assert not any(
(config.azure, config.pulumi, config.shm, config.tags, config.sres)
)

@pytest.mark.parametrize("require_sres", [False, True])
def test_is_complete_bare(self, context, require_sres):
config = Config(context=context)
assert config.is_complete(require_sres=require_sres) is False

def test_constructor(self, context, azure_config, pulumi_config, shm_config, tags_config):
config = Config(
context=context,
azure=azure_config,
pulumi=pulumi_config,
shm=shm_config,
tags=tags_config
)
assert not config.sres

@pytest.mark.parametrize("require_sres,expected", [(False, True), (True, False)])
def test_is_complete_no_sres(self, config_no_sres, require_sres, expected):
assert config_no_sres.is_complete(require_sres=require_sres) is expected

@pytest.mark.parametrize("require_sres", [False, True])
def test_is_complete_sres(self, config_sres, require_sres):
assert config_sres.is_complete(require_sres=require_sres)

0 comments on commit e8e2e5b

Please sign in to comment.