Skip to content

Commit

Permalink
class-based resolver
Browse files Browse the repository at this point in the history
  • Loading branch information
mhils committed Jul 2, 2024
1 parent 558b20d commit 79e9ef5
Show file tree
Hide file tree
Showing 10 changed files with 286 additions and 120 deletions.
61 changes: 61 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ sysinfo = "0.29.10"
env_logger = "0.11"
rand = "0.8"
criterion = "0.5.1"
hickory-server = "0.24.1"


[[bench]]
Expand Down
26 changes: 26 additions & 0 deletions benches/dns.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import mitmproxy_rs
import asyncio
import socket

async def main():
builder = mitmproxy_rs.DnsResolverBuilder()
builder.use_hosts_file(False)
builder.use_nameserver(["8.8.8.8"])
resolver = builder.build()

async def lookup(host: str):
try:
r = await resolver.lookup_ip(host)
except socket.gaierror as e:
print(f"{host=} {e=}")
else:
print(f"{host=} {r=}")

await lookup("example.com.")
await lookup("nxdomain.mitmproxy.org.")
await lookup("no-a-records.mitmproxy.org.")

print(f"{mitmproxy_rs.get_system_dns_servers()=}")


asyncio.run(main())
9 changes: 0 additions & 9 deletions dns.py

This file was deleted.

1 change: 1 addition & 0 deletions mitmproxy-rs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ boringtun = "0.6"
tar = "0.4.41"
console-subscriber = { version = "0.3.0", optional = true }


[dev-dependencies]
env_logger = "0.11"

Expand Down
28 changes: 0 additions & 28 deletions mitmproxy-rs/dns.py

This file was deleted.

13 changes: 11 additions & 2 deletions mitmproxy-rs/mitmproxy_rs/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,16 @@ class Process:
@property
def is_system(self) -> bool: ...

# DNS resolver
# DNS resolution

@final
class DnsResolver:
async def lookup_ip(self, host: str) -> list[str]: ...

@final
class DnsResolverBuilder:
def use_hosts_file(self, value: bool) -> None: ...
def use_name_servers(self, value: list[str]) -> None: ...
def build(self) -> DnsResolver: ...

async def getaddrinfo(host: str, family: int = 0, use_hosts_file: bool = True): list[str]
def get_system_dns_servers(): list[str]
100 changes: 62 additions & 38 deletions mitmproxy-rs/src/dns_resolver.rs
Original file line number Diff line number Diff line change
@@ -1,54 +1,78 @@
use pyo3::types::PyAny;
use pyo3::prelude::*;
use mitmproxy::dns::{
NameServerConfig, Protocol, ResolveErrorKind, ResolverConfig, ResponseCode, DNS_SERVERS,
};
use pyo3::exceptions::socket::gaierror;
use pyo3::prelude::*;
use pyo3::types::PyAny;
use std::{net::IpAddr, net::SocketAddr, sync::Arc};

use mitmproxy::dns::{LookupIpStrategy, ResolveErrorKind, ResponseCode, DNS_SERVERS};
#[pyclass]
pub struct DnsResolverBuilder(mitmproxy::dns::DnsResolverBuilder);

#[pymethods]
impl DnsResolverBuilder {
#[new]
fn new() -> Self {
Self(mitmproxy::dns::DnsResolverBuilder::default())
}

#[pyclass]
#[derive(Copy, Clone)]
pub enum AddressFamily {
Ipv6Only,
Ipv4Only,
DualStack,
}
fn use_hosts_file(&mut self, value: bool) {
self.0.use_hosts_file(value);
}

impl From<AddressFamily> for LookupIpStrategy {
fn from(value: AddressFamily) -> Self {
match value {
AddressFamily::DualStack => LookupIpStrategy::Ipv4AndIpv6,
AddressFamily::Ipv4Only => LookupIpStrategy::Ipv4Only,
AddressFamily::Ipv6Only => LookupIpStrategy::Ipv6Only,
fn use_name_servers(&mut self, value: Vec<IpAddr>) {
let mut conf = ResolverConfig::new();
for ip in value.into_iter() {
let addr = SocketAddr::from((ip, 53));
conf.add_name_server(NameServerConfig::new(addr, Protocol::Udp));
conf.add_name_server(NameServerConfig::new(addr, Protocol::Tcp));
}
self.0.use_config(conf);
}

fn build(&self) -> PyResult<DnsResolver> {
let inner = self.0.build().map_err(|e| {
pyo3::exceptions::PyRuntimeError::new_err(format!(
"failed to build dns resolver: {}",
e
))
})?;
Ok(DnsResolver(Arc::new(inner)))
}
}

#[pyclass]
pub struct DnsResolver(Arc<mitmproxy::dns::DnsResolver>);

#[pyfunction]
#[pyo3(signature = (host, family, use_hosts_file=true))]
pub fn getaddrinfo(py: Python<'_>, host: String, family: AddressFamily, use_hosts_file: bool) -> PyResult<Bound<PyAny>> {
pyo3_asyncio_0_21::tokio::future_into_py(py, async move {
match mitmproxy::dns::getaddrinfo(host, family.into(), use_hosts_file).await {
Ok(resp) => {
Ok(resp.into_iter().map(|ip| ip.to_string()).collect::<Vec<String>>())
},
Err(e) => match *e.kind() {
ResolveErrorKind::NoRecordsFound { response_code: ResponseCode::NXDomain, .. } => {
Err(gaierror::new_err("NXDOMAIN"))
}
ResolveErrorKind::NoRecordsFound { response_code: ResponseCode::NoError, .. } => {
Err(gaierror::new_err("NOERROR"))
}
#[pymethods]
impl DnsResolver {
pub fn lookup_ip<'py>(&self, py: Python<'py>, host: String) -> PyResult<Bound<'py, PyAny>> {
let resolver = self.0.clone();
pyo3_asyncio_0_21::tokio::future_into_py(py, async move {
match resolver.lookup_ip(host).await {
Ok(resp) => Ok(resp
.into_iter()
.map(|ip| ip.to_string())
.collect::<Vec<String>>()),
Err(e) => match *e.kind() {
ResolveErrorKind::NoRecordsFound {
response_code: ResponseCode::NXDomain,
..
} => Err(gaierror::new_err("NXDOMAIN")),
ResolveErrorKind::NoRecordsFound {
response_code: ResponseCode::NoError,
..
} => Err(gaierror::new_err("NOERROR")),
_ => Err(gaierror::new_err(e.to_string())),
}

}
})
},
}
})
}
}

#[pyfunction]
pub fn get_system_dns_servers() -> PyResult<Vec<String>> {
DNS_SERVERS
.clone()
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{}", e)))
DNS_SERVERS.clone().map_err(|e| {
pyo3::exceptions::PyRuntimeError::new_err(format!("failed to get dns servers: {}", e))
})
}
4 changes: 2 additions & 2 deletions mitmproxy-rs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ pub fn mitmproxy_rs(py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<process_info::Process>()?;
m.add_function(wrap_pyfunction!(process_info::executable_icon, m)?)?;

m.add_class::<dns_resolver::AddressFamily>()?;
m.add_function(wrap_pyfunction!(dns_resolver::getaddrinfo, m)?)?;
m.add_class::<dns_resolver::DnsResolverBuilder>()?;
m.add_class::<dns_resolver::DnsResolver>()?;
m.add_function(wrap_pyfunction!(dns_resolver::get_system_dns_servers, m)?)?;

m.add_class::<stream::Stream>()?;
Expand Down
Loading

0 comments on commit 79e9ef5

Please sign in to comment.