Skip to content

Commit

Permalink
Merge pull request #1881 from alan-turing-institute/upload_check
Browse files Browse the repository at this point in the history
Protect against configuration changes
  • Loading branch information
JimMadge authored May 15, 2024
2 parents 9e5b72a + 0724e38 commit 9a4fdcc
Show file tree
Hide file tree
Showing 6 changed files with 202 additions and 20 deletions.
21 changes: 21 additions & 0 deletions data_safe_haven/commands/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from data_safe_haven.config import Config
from data_safe_haven.context import ContextSettings
from data_safe_haven.utility import LoggingSingleton

config_command_group = typer.Typer()

Expand All @@ -34,9 +35,29 @@ def upload(
) -> None:
"""Upload a configuration to the Data Safe Haven context"""
context = ContextSettings.from_file().assert_context()
logger = LoggingSingleton()

# Create configuration object from file
with open(file) as config_file:
config_yaml = config_file.read()
config = Config.from_yaml(config_yaml)

# Present diff to user
if Config.remote_exists(context):
if diff := config.remote_yaml_diff(context):
print("".join(diff))
if not logger.confirm(
(
"Configuration has changed, "
"do you want to overwrite the remote configuration?"
),
default_to_yes=False,
):
raise typer.Exit()
else:
print("No changes, won't upload configuration.")
raise typer.Exit()

config.upload(context)


Expand Down
25 changes: 20 additions & 5 deletions data_safe_haven/serialisers/azure_serialisable_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,31 @@ def from_remote_or_create(
Construct an AzureSerialisableModel from a YAML file in Azure storage, or from
default arguments if no such file exists.
"""
if cls.remote_exists(context):
return cls.from_remote(context)
else:
return cls(**default_args)

@classmethod
def remote_exists(cls: type[T], context: ContextBase) -> bool:
"""Check whether a remote instance of this model exists."""
azure_api = AzureApi(subscription_name=context.subscription_name)
if azure_api.blob_exists(
return azure_api.blob_exists(
cls.filename,
context.resource_group_name,
context.storage_account_name,
context.storage_container_name,
):
return cls.from_remote(context)
else:
return cls(**default_args)
)

def remote_yaml_diff(self, context: ContextBase) -> list[str]:
"""
Determine the diff of YAML output from the remote model to `self`.
The diff is given in unified diff format.
"""
remote_model = self.from_remote(context)

return self.yaml_diff(remote_model, from_name="remote", to_name="local")

def upload(self, context: ContextBase) -> None:
"""Serialise an AzureSerialisableModel to a YAML file in Azure storage."""
Expand Down
18 changes: 18 additions & 0 deletions data_safe_haven/serialisers/yaml_serialisable_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""A pydantic BaseModel that can be serialised to and from YAML"""

from difflib import unified_diff
from pathlib import Path
from typing import ClassVar, TypeVar

Expand Down Expand Up @@ -64,3 +65,20 @@ def to_filepath(self, config_file_path: PathType) -> None:
def to_yaml(self) -> str:
"""Serialise a YAMLSerialisableModel to a YAML string"""
return yaml.dump(self.model_dump(by_alias=True, mode="json"), indent=2)

def yaml_diff(
self, other: T, from_name: str = "other", to_name: str = "self"
) -> list[str]:
"""
Determine the diff of YAML output from `other` to `self`.
The diff is given in unified diff format.
"""
return list(
unified_diff(
other.to_yaml().splitlines(keepends=True),
self.to_yaml().splitlines(keepends=True),
fromfile=from_name,
tofile=to_name,
)
)
76 changes: 73 additions & 3 deletions tests/commands/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,22 +25,92 @@ def test_template_file(self, runner, tmp_path):


class TestUpload:
def test_upload(self, mocker, context, runner, config_yaml, config_file):
mock_method = mocker.patch.object(AzureApi, "upload_blob", return_value=None)
def test_upload_new(self, mocker, context, runner, config_yaml, config_file):
mock_exists = mocker.patch.object(Config, "remote_exists", return_value=False)
mock_upload = mocker.patch.object(AzureApi, "upload_blob", return_value=None)
result = runner.invoke(
config_command_group,
["upload", str(config_file)],
)
assert result.exit_code == 0

mock_method.assert_called_once_with(
mock_exists.assert_called_once_with(context)
mock_upload.assert_called_once_with(
config_yaml,
Config.filename,
context.resource_group_name,
context.storage_account_name,
context.storage_container_name,
)

def test_upload_no_changes(self, mocker, context, runner, config_sres, config_file):
mock_exists = mocker.patch.object(Config, "remote_exists", return_value=True)
mock_from_remote = mocker.patch.object(
Config, "from_remote", return_value=config_sres
)
mock_upload = mocker.patch.object(AzureApi, "upload_blob", return_value=None)
result = runner.invoke(
config_command_group,
["upload", str(config_file)],
)
assert result.exit_code == 0

mock_exists.assert_called_once_with(context)
mock_from_remote.assert_called_once_with(context)
mock_upload.assert_not_called()

assert "No changes, won't upload configuration." in result.stdout

def test_upload_changes(
self, mocker, context, runner, config_no_sres, config_file, config_yaml
):
mock_exists = mocker.patch.object(Config, "remote_exists", return_value=True)
mock_from_remote = mocker.patch.object(
Config, "from_remote", return_value=config_no_sres
)
mock_upload = mocker.patch.object(AzureApi, "upload_blob", return_value=None)
result = runner.invoke(
config_command_group,
["upload", str(config_file)],
input="y\n",
)
assert result.exit_code == 0

mock_exists.assert_called_once_with(context)
mock_from_remote.assert_called_once_with(context)
mock_upload.assert_called_once_with(
config_yaml,
Config.filename,
context.resource_group_name,
context.storage_account_name,
context.storage_container_name,
)

assert "--- remote" in result.stdout
assert "+++ local" in result.stdout

def test_upload_changes_n(
self, mocker, context, runner, config_no_sres, config_file
):
mock_exists = mocker.patch.object(Config, "remote_exists", return_value=True)
mock_from_remote = mocker.patch.object(
Config, "from_remote", return_value=config_no_sres
)
mock_upload = mocker.patch.object(AzureApi, "upload_blob", return_value=None)
result = runner.invoke(
config_command_group,
["upload", str(config_file)],
input="n\n",
)
assert result.exit_code == 0

mock_exists.assert_called_once_with(context)
mock_from_remote.assert_called_once_with(context)
mock_upload.assert_not_called()

assert "--- remote" in result.stdout
assert "+++ local" in result.stdout

def test_upload_no_file(self, mocker, runner):
mocker.patch.object(AzureApi, "upload_blob", return_value=None)
result = runner.invoke(
Expand Down
42 changes: 37 additions & 5 deletions tests/serialisers/test_azure_serialisable_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def example_config_class():

@fixture
def example_config_yaml():
return "\n".join(["string: 'abc'", "integer: -3", "list_of_integers: [-1,0,1]"])
return "\n".join(["string: 'hello'", "integer: 5", "list_of_integers: [1,2,3]"])


class TestAzureSerialisableModel:
Expand All @@ -34,6 +34,38 @@ def test_constructor(self, example_config_class):
assert isinstance(example_config_class, AzureSerialisableModel)
assert example_config_class.string == "hello"

def test_remote_yaml_diff(self, mocker, example_config_class, context):
mocker.patch.object(
AzureApi, "download_blob", return_value=example_config_class.to_yaml()
)
diff = example_config_class.remote_yaml_diff(context)
assert not diff
assert diff == []

def test_remote_yaml_diff_difference(self, mocker, example_config_class, context):
mocker.patch.object(
AzureApi, "download_blob", return_value=example_config_class.to_yaml()
)
example_config_class.integer = 0
example_config_class.string = "abc"

diff = example_config_class.remote_yaml_diff(context)

assert isinstance(diff, list)
assert diff == [
"--- remote\n",
"+++ local\n",
"@@ -1,6 +1,6 @@\n",
"-integer: 5\n",
"+integer: 0\n",
" list_of_integers:\n",
" - 1\n",
" - 2\n",
" - 3\n",
"-string: hello\n",
"+string: abc\n",
]

def test_to_yaml(self, example_config_class):
yaml = example_config_class.to_yaml()
assert isinstance(yaml, str)
Expand All @@ -59,9 +91,9 @@ def test_from_yaml(self, example_config_yaml):
)
assert isinstance(example_config_class, ExampleAzureSerialisableModel)
assert isinstance(example_config_class, AzureSerialisableModel)
assert example_config_class.string == "abc"
assert example_config_class.integer == -3
assert example_config_class.list_of_integers == [-1, 0, 1]
assert example_config_class.string == "hello"
assert example_config_class.integer == 5
assert example_config_class.list_of_integers == [1, 2, 3]

def test_from_yaml_invalid_yaml(self):
yaml = "\n".join(["string: 'abc'", "integer: -3", "list_of_integers: [-1,0,1"])
Expand Down Expand Up @@ -97,7 +129,7 @@ def test_from_remote(self, mocker, context, example_config_yaml):
example_config = ExampleAzureSerialisableModel.from_remote(context)

assert isinstance(example_config, ExampleAzureSerialisableModel)
assert example_config.string == "abc"
assert example_config.string == "hello"

mock_method.assert_called_once_with(
"file.yaml",
Expand Down
40 changes: 33 additions & 7 deletions tests/serialisers/test_yaml_serialisable_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def example_config_class():

@fixture
def example_config_yaml():
return "\n".join(["string: 'abc'", "integer: -3", "list_of_integers: [-1,0,1]"])
return "\n".join(["string: 'hello'", "integer: 5", "list_of_integers: [1,2,3]"])


class TestYAMLSerialisableModel:
Expand All @@ -38,19 +38,19 @@ def test_from_filepath(self, tmp_path, example_config_yaml):
example_config_class = ExampleYAMLSerialisableModel.from_filepath(filepath)
assert isinstance(example_config_class, ExampleYAMLSerialisableModel)
assert isinstance(example_config_class, YAMLSerialisableModel)
assert example_config_class.string == "abc"
assert example_config_class.integer == -3
assert example_config_class.list_of_integers == [-1, 0, 1]
assert example_config_class.string == "hello"
assert example_config_class.integer == 5
assert example_config_class.list_of_integers == [1, 2, 3]

def test_from_yaml(self, example_config_yaml):
example_config_class = ExampleYAMLSerialisableModel.from_yaml(
example_config_yaml
)
assert isinstance(example_config_class, ExampleYAMLSerialisableModel)
assert isinstance(example_config_class, YAMLSerialisableModel)
assert example_config_class.string == "abc"
assert example_config_class.integer == -3
assert example_config_class.list_of_integers == [-1, 0, 1]
assert example_config_class.string == "hello"
assert example_config_class.integer == 5
assert example_config_class.list_of_integers == [1, 2, 3]

def test_from_yaml_invalid_yaml(self):
yaml = "\n".join(["string: 'abc'", "integer: -3", "list_of_integers: [-1,0,1"])
Expand Down Expand Up @@ -95,3 +95,29 @@ def test_to_yaml(self, example_config_class):
assert "string: hello" in yaml
assert "integer: 5" in yaml
assert "config_type" not in yaml

def test_yaml_diff(self, example_config_class):
other = example_config_class.model_copy(deep=True)
diff = example_config_class.yaml_diff(other)
assert not diff
assert diff == []

def test_yaml_diff_difference(self, example_config_class):
other = example_config_class.model_copy(deep=True)
other.integer = 3
other.string = "abc"
diff = example_config_class.yaml_diff(other)
assert isinstance(diff, list)
assert diff == [
"--- other\n",
"+++ self\n",
"@@ -1,6 +1,6 @@\n",
"-integer: 3\n",
"+integer: 5\n",
" list_of_integers:\n",
" - 1\n",
" - 2\n",
" - 3\n",
"-string: abc\n",
"+string: hello\n",
]

0 comments on commit 9a4fdcc

Please sign in to comment.