Skip to content

Commit

Permalink
add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mosquito committed Jun 4, 2024
1 parent a6f07ee commit 1a5db6d
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 11 deletions.
3 changes: 3 additions & 0 deletions aiomisc/service/dns/service.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import logging
from asyncio import StreamReader, StreamWriter
from struct import Struct
Expand Down Expand Up @@ -123,6 +124,8 @@ async def handle_client(
writer.write(TCP_HEADER_STRUCT.pack(len(reply_body)))
writer.write(reply_body)
await writer.drain()
except asyncio.IncompleteReadError:
pass
finally:
writer.close()
await writer.wait_closed()
108 changes: 97 additions & 11 deletions tests/test_dns_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
import pytest

from aiomisc import threaded
from aiomisc.service.dns import DNSStore, DNSZone, UDPDNSServer
from aiomisc.service.dns import DNSStore, DNSZone, TCPDNSServer, UDPDNSServer
from aiomisc.service.dns.records import AAAA, CNAME, A, RecordType
from aiomisc.service.dns.service import TCP_HEADER_STRUCT


@pytest.fixture
Expand Down Expand Up @@ -34,26 +35,53 @@ def dns_server_port(aiomisc_unused_port_factory) -> int:


@pytest.fixture
def dns_server(dns_store_filled: DNSStore, dns_server_port) -> UDPDNSServer:
def dns_server_udp(
dns_store_filled: DNSStore, dns_server_port,
) -> UDPDNSServer:
server = UDPDNSServer(
store=dns_store_filled, address="localhost", port=dns_server_port,
)
return server


@pytest.fixture
def services(dns_server):
return [dns_server]
def dns_server_tcp(
dns_store_filled: DNSStore, dns_server_port,
) -> TCPDNSServer:
server = TCPDNSServer(
store=dns_store_filled, address="localhost", port=dns_server_port,
)
return server


@pytest.fixture
def services(dns_server_udp, dns_server_tcp):
return [dns_server_udp, dns_server_tcp]


@threaded
def dns_send_receive(data, port):
def dns_send_receive_udp(data, port):
# Send the query
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
sock.settimeout(5)
sock.sendto(data, ("localhost", port))
# Receive the response
response_data, _ = sock.recvfrom(512)
response_data, _ = sock.recvfrom(65535)
return dnslib.DNSRecord.parse(response_data)


@threaded
def dns_send_receive_tcp(data, port):
# Send the query
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.settimeout(5)
sock.connect(("localhost", port))
sock.sendall(data)
# Receive the response length (2 bytes)
length_data = sock.recv(2)
length = int.from_bytes(length_data, byteorder="big")
# Receive the actual response
response_data = sock.recv(length)
return dnslib.DNSRecord.parse(response_data)


Expand All @@ -62,7 +90,7 @@ async def test_handle_datagram_a_record(services, dns_server_port):
query = dnslib.DNSRecord.question("sub.example.com.", qtype="A")
query_data = query.pack()

response = await dns_send_receive(query_data, dns_server_port)
response = await dns_send_receive_udp(query_data, dns_server_port)

# Verify the response
assert response.header.rcode == dnslib.RCODE.NOERROR
Expand All @@ -78,7 +106,7 @@ async def test_handle_datagram_aaaa_record(services, dns_server_port):
)
query_data = query.pack()

response = await dns_send_receive(query_data, dns_server_port)
response = await dns_send_receive_udp(query_data, dns_server_port)

# Verify the response
assert response.header.rcode == dnslib.RCODE.NOERROR
Expand All @@ -94,7 +122,7 @@ async def test_handle_datagram_cname_record(services, dns_server_port):
)
query_data = query.pack()

response = await dns_send_receive(query_data, dns_server_port)
response = await dns_send_receive_udp(query_data, dns_server_port)

# Verify the response
assert response.header.rcode == dnslib.RCODE.NOERROR
Expand All @@ -110,7 +138,7 @@ async def test_handle_datagram_nonexistent_record(services, dns_server_port):
)
query_data = query.pack()

response = await dns_send_receive(query_data, dns_server_port)
response = await dns_send_receive_udp(query_data, dns_server_port)

# Verify the response
assert response.header.rcode == dnslib.RCODE.NXDOMAIN
Expand All @@ -130,8 +158,66 @@ async def test_handle_datagram_remove_record(
query = dnslib.DNSRecord.question("sub.example.com.", qtype="A")
query_data = query.pack()

response = await dns_send_receive(query_data, dns_server_port)
response = await dns_send_receive_udp(query_data, dns_server_port)

# Verify the response
assert response.header.rcode == dnslib.RCODE.NXDOMAIN
assert len(response.rr) == 0


async def test_handle_datagram_edns_record(services, dns_server_port):
# Prepare a DNS query with EDNS0
query = dnslib.DNSRecord.question("sub.example.com.", qtype="A")
query.add_ar(dnslib.EDNS0(udp_len=4096))
query_data = query.pack()

response = await dns_send_receive_udp(query_data, dns_server_port)

# Verify the response
assert response.header.rcode == dnslib.RCODE.NOERROR
assert len(response.rr) == 1
assert str(response.rr[0].rname) == "sub.example.com."
assert response.rr[0].rdata == dnslib.A("192.0.2.1")
assert response.ar[0].rtype == dnslib.QTYPE.OPT


async def test_handle_large_datagram_truncated_udp(
services, dns_server_port, dns_store_filled,
):
# Add many A records to the store to exceed typical UDP packet size
zone = dns_store_filled.get_zone("example.com.")
for i in range(1, 101):
a_record = A.create(name="rr.example.com.", ip=f"192.0.2.{i}")
zone.add_record(a_record)

# Prepare a DNS query
query = dnslib.DNSRecord.question("rr.example.com.", qtype="A")
query_data = query.pack()

response = await dns_send_receive_udp(query_data, dns_server_port)

# Verify the response is truncated
assert response.header.tc == 1 # Ensure the TC (truncated) bit is set


async def test_handle_large_tcp_request(
services, dns_server_port, dns_store_filled,
):
# Add many A records to the store to exceed typical UDP packet size
zone = dns_store_filled.get_zone("example.com.")
for i in range(1, 101):
a_record = A.create(
name="rr.example.com.", ip=f"192.0.2.{i}", ttl=3600,
)
zone.add_record(a_record)

# Prepare a DNS query for TCP
query = dnslib.DNSRecord.question("rr.example.com.", qtype="A")
query_data = TCP_HEADER_STRUCT.pack(len(query.pack())) + query.pack()

response = await dns_send_receive_tcp(query_data, dns_server_port)

# Verify the TCP response is not truncated and contains all records
assert response.header.rcode == dnslib.RCODE.NOERROR
assert len(response.rr) == 100
assert all(str(rr.rname).startswith("rr") for rr in response.rr)

0 comments on commit 1a5db6d

Please sign in to comment.