diff --git a/data_safe_haven/config/config_sections.py b/data_safe_haven/config/config_sections.py index 252c94e7d0..35b9570a7e 100644 --- a/data_safe_haven/config/config_sections.py +++ b/data_safe_haven/config/config_sections.py @@ -10,6 +10,7 @@ from data_safe_haven.types import ( AzureLocation, AzurePremiumFileShareSize, + AzureServiceTag, AzureVmSku, DatabaseSystem, EmailAddress, @@ -58,7 +59,7 @@ class ConfigSectionSRE(BaseModel, validate_assignment=True): databases: UniqueList[DatabaseSystem] = [] data_provider_ip_addresses: list[IpAddress] = [] remote_desktop: ConfigSubsectionRemoteDesktopOpts - research_user_ip_addresses: list[IpAddress] = [] + research_user_ip_addresses: list[IpAddress] | AzureServiceTag = [] storage_quota_gb: ConfigSubsectionStorageQuotaGB software_packages: SoftwarePackageCategory = SoftwarePackageCategory.NONE timezone: TimeZone = "Etc/UTC" @@ -67,7 +68,7 @@ class ConfigSectionSRE(BaseModel, validate_assignment=True): @field_validator( "admin_ip_addresses", "data_provider_ip_addresses", - "research_user_ip_addresses", + # "research_user_ip_addresses", mode="after", ) @classmethod @@ -78,3 +79,16 @@ def ensure_non_overlapping(cls, v: list[IpAddress]) -> list[IpAddress]: msg = "IP addresses must not overlap." raise ValueError(msg) return v + + @field_validator( + "research_user_ip_addresses", + mode="after", + ) + @classmethod + def ensure_non_overlapping_or_tag( + cls, v: list[IpAddress] | AzureServiceTag + ) -> list[IpAddress] | AzureServiceTag: + if isinstance(v, list): + return cls.ensure_non_overlapping(v) + else: + return v diff --git a/data_safe_haven/config/sre_config.py b/data_safe_haven/config/sre_config.py index f4ee5ed6c9..9fba89e12f 100644 --- a/data_safe_haven/config/sre_config.py +++ b/data_safe_haven/config/sre_config.py @@ -98,7 +98,10 @@ def template(cls: type[Self], tier: int | None = None) -> SREConfig: allow_copy=remote_desktop_allow_copy, allow_paste=remote_desktop_allow_paste, ), - research_user_ip_addresses=["List of IP addresses belonging to users"], + research_user_ip_addresses=[ + "List of IP addresses belonging to users", + "You can also use the tag 'Internet' instead of a list", + ], software_packages=software_packages, storage_quota_gb=ConfigSubsectionStorageQuotaGB.model_construct( home="Total size in GiB across all home directories [minimum: 100].", # type: ignore diff --git a/data_safe_haven/infrastructure/programs/sre/networking.py b/data_safe_haven/infrastructure/programs/sre/networking.py index 42e1345c2d..e6c308f587 100644 --- a/data_safe_haven/infrastructure/programs/sre/networking.py +++ b/data_safe_haven/infrastructure/programs/sre/networking.py @@ -12,7 +12,7 @@ get_id_from_vnet, get_name_from_vnet, ) -from data_safe_haven.types import NetworkingPriorities, Ports +from data_safe_haven.types import AzureServiceTag, NetworkingPriorities, Ports class SRENetworkingProps: @@ -31,7 +31,7 @@ def __init__( shm_subscription_id: Input[str], shm_zone_name: Input[str], sre_name: Input[str], - user_public_ip_ranges: Input[list[str]], + user_public_ip_ranges: Input[list[str]] | AzureServiceTag, ) -> None: # Other variables self.dns_private_zones = dns_private_zones @@ -68,6 +68,13 @@ def __init__( child_opts = ResourceOptions.merge(opts, ResourceOptions(parent=self)) child_tags = {"component": "networking"} | (tags if tags else {}) + if isinstance(props.user_public_ip_ranges, list): + user_public_ip_ranges = props.user_public_ip_ranges + user_service_tag = None + else: + user_public_ip_ranges = None + user_service_tag = props.user_public_ip_ranges + # Define route table route_table = network.RouteTable( f"{self._name}_route_table", @@ -125,7 +132,8 @@ def __init__( name="AllowUsersInternetInbound", priority=NetworkingPriorities.AUTHORISED_EXTERNAL_USER_IPS, protocol=network.SecurityRuleProtocol.TCP, - source_address_prefixes=props.user_public_ip_ranges, + source_address_prefix=user_service_tag, + source_address_prefixes=user_public_ip_ranges, source_port_range="*", ), network.SecurityRuleArgs( diff --git a/data_safe_haven/types/__init__.py b/data_safe_haven/types/__init__.py index 471fb56656..4f2f89b3be 100644 --- a/data_safe_haven/types/__init__.py +++ b/data_safe_haven/types/__init__.py @@ -15,6 +15,7 @@ from .enums import ( AzureDnsZoneNames, AzureSdkCredentialScope, + AzureServiceTag, DatabaseSystem, FirewallPriorities, ForbiddenDomains, @@ -29,6 +30,7 @@ "AzureDnsZoneNames", "AzureLocation", "AzurePremiumFileShareSize", + "AzureServiceTag", "AzureSdkCredentialScope", "AzureSubscriptionName", "AzureVmSku", diff --git a/data_safe_haven/types/enums.py b/data_safe_haven/types/enums.py index 170cbba4a0..35465f260e 100644 --- a/data_safe_haven/types/enums.py +++ b/data_safe_haven/types/enums.py @@ -26,6 +26,11 @@ class AzureSdkCredentialScope(str, Enum): KEY_VAULT = "https://vault.azure.net" +@verify(UNIQUE) +class AzureServiceTag(str, Enum): + INTERNET = "Internet" + + @verify(UNIQUE) class DatabaseSystem(str, Enum): MICROSOFT_SQL_SERVER = "mssql" diff --git a/data_safe_haven/validators/validators.py b/data_safe_haven/validators/validators.py index dd4458ec57..27507d26b4 100644 --- a/data_safe_haven/validators/validators.py +++ b/data_safe_haven/validators/validators.py @@ -124,7 +124,7 @@ def ip_address(ip_address: str) -> str: try: return str(ipaddress.ip_network(ip_address)) except Exception as exc: - msg = "Expected valid IPv4 address, for example '1.1.1.1'." + msg = "Expected valid IPv4 address, for example '1.1.1.1', or 'Internet'." raise ValueError(msg) from exc diff --git a/tests/config/test_config_sections.py b/tests/config/test_config_sections.py index 0363e41e38..6528b130fa 100644 --- a/tests/config/test_config_sections.py +++ b/tests/config/test_config_sections.py @@ -9,7 +9,11 @@ ConfigSubsectionRemoteDesktopOpts, ConfigSubsectionStorageQuotaGB, ) -from data_safe_haven.types import DatabaseSystem, SoftwarePackageCategory +from data_safe_haven.types import ( + AzureServiceTag, + DatabaseSystem, + SoftwarePackageCategory, +) class TestConfigSectionAzure: @@ -184,6 +188,24 @@ def test_ip_overlap_research_user(self): research_user_ip_addresses=["1.2.3.4", "1.2.3.4"], ) + def test_research_user_tag_internet( + self, + config_subsection_remote_desktop: ConfigSubsectionRemoteDesktopOpts, + config_subsection_storage_quota_gb: ConfigSubsectionStorageQuotaGB, + ): + sre_config = ConfigSectionSRE( + admin_email_address="admin@example.com", + remote_desktop=config_subsection_remote_desktop, + storage_quota_gb=config_subsection_storage_quota_gb, + research_user_ip_addresses="Internet", + ) + assert isinstance(sre_config.research_user_ip_addresses, AzureServiceTag) + assert sre_config.research_user_ip_addresses == "Internet" + + def test_research_user_tag_invalid(self): + with pytest.raises(ValueError, match="Input should be 'Internet'"): + ConfigSectionSRE(research_user_ip_addresses="Not a tag") + @pytest.mark.parametrize( "addresses", [ diff --git a/tests/validators/test_validators.py b/tests/validators/test_validators.py index 1c38e551f8..18d2fd31b5 100644 --- a/tests/validators/test_validators.py +++ b/tests/validators/test_validators.py @@ -86,6 +86,36 @@ def test_fqdn_fail(self, fqdn): validators.fqdn(fqdn) +class TestValidateIpAddress: + @pytest.mark.parametrize( + "ip_address,output", + [ + ("127.0.0.1", "127.0.0.1/32"), + ("0.0.0.0/0", "0.0.0.0/0"), + ("192.168.171.1/32", "192.168.171.1/32"), + ], + ) + def test_ip_address(self, ip_address, output): + assert validators.ip_address(ip_address) == output + + @pytest.mark.parametrize( + "ip_address", + [ + "example.com", + "University of Life", + "999.999.999.999", + "0.0.0.0/-1", + "255.255.255.0/2", + ], + ) + def test_ip_address_fail(self, ip_address): + with pytest.raises( + ValueError, + match="Expected valid IPv4 address, for example '1.1.1.1', or 'Internet'.", + ): + validators.ip_address(ip_address) + + class TestValidateSafeString: @pytest.mark.parametrize( "safe_string",