diff --git a/aiomisc/service/dns/store.py b/aiomisc/service/dns/store.py index d765ed17..bd3d976f 100644 --- a/aiomisc/service/dns/store.py +++ b/aiomisc/service/dns/store.py @@ -1,4 +1,4 @@ -from typing import Optional, Sequence, Tuple +from typing import Iterable, Mapping, Optional, Sequence, Tuple from .records import DNSRecord, RecordType from .tree import RadixTree @@ -49,3 +49,22 @@ def get_zone_for_name(self, name: str) -> Optional[Tuple[str, ...]]: @staticmethod def get_reverse_tuple(zone_name: str) -> Tuple[str, ...]: return tuple(zone_name.strip(".").split("."))[::-1] + + def replace( + self, zones_data: Mapping[str, Iterable[DNSRecord]], + ) -> None: + """ + Atomically replace all zones with new ones this method is safe + because it replaces all zones at once. zone_data is a mapping + zone name and a sequence of DNSRecord objects. + + If any of the zones or records is invalid, nothing will be replaced. + + This method is useful for reload configuration from disk + or database or etc. + """ + new_zones: RadixTree[DNSZone] = RadixTree() + for zone_name, records in zones_data.items(): + zone = DNSZone(zone_name, *records) + new_zones.insert(self.get_reverse_tuple(zone.name), zone) + self.zones = new_zones diff --git a/aiomisc/service/dns/zone.py b/aiomisc/service/dns/zone.py index 60224b46..8ed501c7 100644 --- a/aiomisc/service/dns/zone.py +++ b/aiomisc/service/dns/zone.py @@ -1,28 +1,33 @@ from collections import defaultdict -from typing import DefaultDict, Sequence, Set, Tuple +from typing import DefaultDict, Iterable, Sequence, Set, Tuple from .records import DNSRecord, RecordType +RecordsType = DefaultDict[Tuple[str, RecordType], Set[DNSRecord]] + + class DNSZone: - records: DefaultDict[Tuple[str, RecordType], Set[DNSRecord]] + records: RecordsType name: str __slots__ = ("name", "records") - def __init__(self, name: str): + def __init__(self, name: str, *records: DNSRecord) -> None: if not name.endswith("."): name += "." self.name = name self.records = defaultdict(set) + for record in records: + self.add_record(record) + def add_record(self, record: DNSRecord) -> None: - if not self._is_valid_record(record): + if not self.check_record(record): raise ValueError( f"Record {record.name} does not belong to zone {self.name}", ) - key = (record.name, record.type) - self.records[key].add(record) + self.records[(record.name, record.type)].add(record) def remove_record(self, record: DNSRecord) -> None: key = (record.name, record.type) @@ -37,8 +42,26 @@ def get_records( ) -> Sequence[DNSRecord]: if not name.endswith("."): name += "." - key = (name, record_type) - return tuple(self.records.get(key, ())) + return tuple(self.records.get((name, record_type), ())) - def _is_valid_record(self, record: DNSRecord) -> bool: + def check_record(self, record: DNSRecord) -> bool: return record.name.endswith(self.name) + + def replace(self, records: Iterable[DNSRecord]) -> None: + """ + Atomically replace all records in specified zone with new ones. + This method is safe because it replaces all records at once. + + If any of the records does not belong to the zone, ValueError + will be raised and no records will be replaced. + """ + new_records: RecordsType = defaultdict(set) + + for record in records: + if not self.check_record(record): + raise ValueError( + f"Record {record.name} does not " + f"belong to zone {self.name}", + ) + new_records[(record.name, record.type)].add(record) + self.records = new_records diff --git a/tests/test_dns.py b/tests/test_dns.py index 955c80a1..cbb01427 100644 --- a/tests/test_dns.py +++ b/tests/test_dns.py @@ -516,3 +516,111 @@ def test_sshfp_create(): assert record.data.fp_type == 1 assert record.data.fingerprint == b"abcdefg" assert record.ttl == 300 + + +def test_zone_replace(dns_store): + zone = DNSZone("example.com.") + record1 = A.create(name="www.example.com.", ip="192.0.2.1") + record2 = A.create(name="api.example.com.", ip="192.0.2.2") + zone.add_record(record1) + dns_store.add_zone(zone) + + zone.replace([record2]) + + records = dns_store.query("www.example.com.", RecordType.A) + assert len(records) == 0 + records = dns_store.query("api.example.com.", RecordType.A) + assert len(records) == 1 + assert record2 in records + + +def test_zone_replace_multiple_records(): + zone = DNSZone("example.com.") + record1 = A.create(name="www.example.com.", ip="192.0.2.1") + record2 = A.create(name="www.example.com.", ip="192.0.2.2") + + zone.replace([record1, record2]) + records = zone.get_records("www.example.com.", RecordType.A) + assert len(records) == 2 + assert record1 in records + assert record2 in records + + +def test_zone_replace_empty(): + zone = DNSZone("example.com.") + record = A.create(name="www.example.com.", ip="192.0.2.1") + zone.add_record(record) + + zone.replace([]) + records = zone.get_records("www.example.com.", RecordType.A) + assert len(records) == 0 + + +def test_zone_replace_invalid_record(): + zone = DNSZone("example.com.") + record = A.create(name="www.other.com.", ip="192.0.2.1") + + with pytest.raises(ValueError, match="does not belong to zone"): + zone.replace([record]) + + +def test_store_replace_basic(dns_store): + zone1 = DNSZone("example.com.") + record1 = A.create(name="www.example.com.", ip="192.0.2.1") + zone1.add_record(record1) + dns_store.add_zone(zone1) + + zone2 = DNSZone("test.com.") + record2 = A.create(name="www.test.com.", ip="192.0.2.2") + zone2.add_record(record2) + dns_store.add_zone(zone2) + + # Replace with new data + new_record1 = A.create(name="api.example.com.", ip="192.0.2.3") + new_record2 = A.create(name="api.test.com.", ip="192.0.2.4") + + dns_store.replace({ + "example.com.": [new_record1], + "test.com.": [new_record2], + }) + + # Check old records are gone + records = dns_store.query("www.example.com.", RecordType.A) + assert len(records) == 0 + records = dns_store.query("www.test.com.", RecordType.A) + assert len(records) == 0 + + # Check new records are present + records = dns_store.query("api.example.com.", RecordType.A) + assert len(records) == 1 + assert new_record1 in records + records = dns_store.query("api.test.com.", RecordType.A) + assert len(records) == 1 + assert new_record2 in records + + +def test_store_replace_empty(dns_store): + zone = DNSZone("example.com.") + record = A.create(name="www.example.com.", ip="192.0.2.1") + zone.add_record(record) + dns_store.add_zone(zone) + + dns_store.replace({}) + + assert dns_store.get_zone("example.com.") is None + records = dns_store.query("www.example.com.", RecordType.A) + assert len(records) == 0 + + +def test_store_replace_multiple_records_per_zone(dns_store): + new_record1 = A.create(name="www.example.com.", ip="192.0.2.1") + new_record2 = A.create(name="www.example.com.", ip="192.0.2.2") + + dns_store.replace({ + "example.com.": [new_record1, new_record2], + }) + + records = dns_store.query("www.example.com.", RecordType.A) + assert len(records) == 2 + assert new_record1 in records + assert new_record2 in records