Skip to content

Commit

Permalink
Safe Replace DNSStore content and DNSZone content (#219)
Browse files Browse the repository at this point in the history
  • Loading branch information
mosquito authored Dec 31, 2024
1 parent e990289 commit 25ac87f
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 10 deletions.
21 changes: 20 additions & 1 deletion aiomisc/service/dns/store.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Sequence, Tuple
from typing import Iterable, Mapping, Optional, Sequence, Tuple

from .records import DNSRecord, RecordType
from .tree import RadixTree
Expand Down Expand Up @@ -49,3 +49,22 @@ def get_zone_for_name(self, name: str) -> Optional[Tuple[str, ...]]:
@staticmethod
def get_reverse_tuple(zone_name: str) -> Tuple[str, ...]:
return tuple(zone_name.strip(".").split("."))[::-1]

def replace(
self, zones_data: Mapping[str, Iterable[DNSRecord]],
) -> None:
"""
Atomically replace all zones with new ones this method is safe
because it replaces all zones at once. zone_data is a mapping
zone name and a sequence of DNSRecord objects.
If any of the zones or records is invalid, nothing will be replaced.
This method is useful for reload configuration from disk
or database or etc.
"""
new_zones: RadixTree[DNSZone] = RadixTree()
for zone_name, records in zones_data.items():
zone = DNSZone(zone_name, *records)
new_zones.insert(self.get_reverse_tuple(zone.name), zone)
self.zones = new_zones
41 changes: 32 additions & 9 deletions aiomisc/service/dns/zone.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,33 @@
from collections import defaultdict
from typing import DefaultDict, Sequence, Set, Tuple
from typing import DefaultDict, Iterable, Sequence, Set, Tuple

from .records import DNSRecord, RecordType


RecordsType = DefaultDict[Tuple[str, RecordType], Set[DNSRecord]]


class DNSZone:
records: DefaultDict[Tuple[str, RecordType], Set[DNSRecord]]
records: RecordsType
name: str

__slots__ = ("name", "records")

def __init__(self, name: str):
def __init__(self, name: str, *records: DNSRecord) -> None:
if not name.endswith("."):
name += "."
self.name = name
self.records = defaultdict(set)

for record in records:
self.add_record(record)

def add_record(self, record: DNSRecord) -> None:
if not self._is_valid_record(record):
if not self.check_record(record):
raise ValueError(
f"Record {record.name} does not belong to zone {self.name}",
)
key = (record.name, record.type)
self.records[key].add(record)
self.records[(record.name, record.type)].add(record)

def remove_record(self, record: DNSRecord) -> None:
key = (record.name, record.type)
Expand All @@ -37,8 +42,26 @@ def get_records(
) -> Sequence[DNSRecord]:
if not name.endswith("."):
name += "."
key = (name, record_type)
return tuple(self.records.get(key, ()))
return tuple(self.records.get((name, record_type), ()))

def _is_valid_record(self, record: DNSRecord) -> bool:
def check_record(self, record: DNSRecord) -> bool:
return record.name.endswith(self.name)

def replace(self, records: Iterable[DNSRecord]) -> None:
"""
Atomically replace all records in specified zone with new ones.
This method is safe because it replaces all records at once.
If any of the records does not belong to the zone, ValueError
will be raised and no records will be replaced.
"""
new_records: RecordsType = defaultdict(set)

for record in records:
if not self.check_record(record):
raise ValueError(
f"Record {record.name} does not "
f"belong to zone {self.name}",
)
new_records[(record.name, record.type)].add(record)
self.records = new_records
108 changes: 108 additions & 0 deletions tests/test_dns.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,3 +516,111 @@ def test_sshfp_create():
assert record.data.fp_type == 1
assert record.data.fingerprint == b"abcdefg"
assert record.ttl == 300


def test_zone_replace(dns_store):
zone = DNSZone("example.com.")
record1 = A.create(name="www.example.com.", ip="192.0.2.1")
record2 = A.create(name="api.example.com.", ip="192.0.2.2")
zone.add_record(record1)
dns_store.add_zone(zone)

zone.replace([record2])

records = dns_store.query("www.example.com.", RecordType.A)
assert len(records) == 0
records = dns_store.query("api.example.com.", RecordType.A)
assert len(records) == 1
assert record2 in records


def test_zone_replace_multiple_records():
zone = DNSZone("example.com.")
record1 = A.create(name="www.example.com.", ip="192.0.2.1")
record2 = A.create(name="www.example.com.", ip="192.0.2.2")

zone.replace([record1, record2])
records = zone.get_records("www.example.com.", RecordType.A)
assert len(records) == 2
assert record1 in records
assert record2 in records


def test_zone_replace_empty():
zone = DNSZone("example.com.")
record = A.create(name="www.example.com.", ip="192.0.2.1")
zone.add_record(record)

zone.replace([])
records = zone.get_records("www.example.com.", RecordType.A)
assert len(records) == 0


def test_zone_replace_invalid_record():
zone = DNSZone("example.com.")
record = A.create(name="www.other.com.", ip="192.0.2.1")

with pytest.raises(ValueError, match="does not belong to zone"):
zone.replace([record])


def test_store_replace_basic(dns_store):
zone1 = DNSZone("example.com.")
record1 = A.create(name="www.example.com.", ip="192.0.2.1")
zone1.add_record(record1)
dns_store.add_zone(zone1)

zone2 = DNSZone("test.com.")
record2 = A.create(name="www.test.com.", ip="192.0.2.2")
zone2.add_record(record2)
dns_store.add_zone(zone2)

# Replace with new data
new_record1 = A.create(name="api.example.com.", ip="192.0.2.3")
new_record2 = A.create(name="api.test.com.", ip="192.0.2.4")

dns_store.replace({
"example.com.": [new_record1],
"test.com.": [new_record2],
})

# Check old records are gone
records = dns_store.query("www.example.com.", RecordType.A)
assert len(records) == 0
records = dns_store.query("www.test.com.", RecordType.A)
assert len(records) == 0

# Check new records are present
records = dns_store.query("api.example.com.", RecordType.A)
assert len(records) == 1
assert new_record1 in records
records = dns_store.query("api.test.com.", RecordType.A)
assert len(records) == 1
assert new_record2 in records


def test_store_replace_empty(dns_store):
zone = DNSZone("example.com.")
record = A.create(name="www.example.com.", ip="192.0.2.1")
zone.add_record(record)
dns_store.add_zone(zone)

dns_store.replace({})

assert dns_store.get_zone("example.com.") is None
records = dns_store.query("www.example.com.", RecordType.A)
assert len(records) == 0


def test_store_replace_multiple_records_per_zone(dns_store):
new_record1 = A.create(name="www.example.com.", ip="192.0.2.1")
new_record2 = A.create(name="www.example.com.", ip="192.0.2.2")

dns_store.replace({
"example.com.": [new_record1, new_record2],
})

records = dns_store.query("www.example.com.", RecordType.A)
assert len(records) == 2
assert new_record1 in records
assert new_record2 in records

0 comments on commit 25ac87f

Please sign in to comment.