Rework the codebase to genericise TCP/UDP handling

This commit is contained in:
2026-05-02 15:23:37 +01:00
parent 8e7220f704
commit ffd3a5b268
+188 -120
View File
@@ -6,12 +6,13 @@ mod dns_parser;
mod ip_pool;
use std::collections::HashMap;
use std::net::Ipv4Addr;
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
use std::process::{Command, Output};
use std::sync::OnceLock;
use std::sync::{Arc, OnceLock};
use std::time::{Duration, Instant};
use ini::Ini;
use tokio::net::TcpSocket;
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::{TcpListener, TcpStream, UdpSocket},
@@ -121,32 +122,7 @@ async fn teardown_forwarding(
let _dnat_output = run_command_string(dnat_command).unwrap();
}
async fn query_upstream_resolvers(original_message: &[u8], tcp: bool) -> anyhow::Result<Vec<u8>> {
let mut upstream_reply = Vec::new();
if tcp {
let mut upstream = TcpStream::connect(UPSTREAM_DNS.get().unwrap()).await?;
//upstream.write(&(original_message.len() as u16).to_be_bytes()).await?;
upstream.write(&original_message).await?;
let mut size_buffer = [0u8; 2];
upstream.read(&mut size_buffer).await?;
let size: u16 = u16::from_be_bytes(size_buffer);
upstream_reply.resize(size as usize + 2, 0u8);
upstream.read(&mut upstream_reply[2..]).await?;
upstream_reply[0] = size_buffer[0];
upstream_reply[1] = size_buffer[1];
} else {
let upstream = UdpSocket::bind("0.0.0.0:0").await?;
upstream.connect(UPSTREAM_DNS.get().unwrap()).await?;
upstream.send(original_message).await?;
upstream_reply.resize(512, 0);
upstream.recv_from(&mut upstream_reply[..512]).await?;
}
return Ok(upstream_reply);
}
fn forge_replies(replies: &Vec<Forwarding>, dns_name_string: String, qname_parts: Vec<Vec<u8>>, original_message: &[u8], reply_buf: &mut Vec<u8>, tcp: bool) {
fn forge_replies(replies: &Vec<Forwarding>, dns_name_string: String, qname_parts: Vec<Vec<u8>>, original_message: &[u8], reply_buf: &mut Vec<u8>) {
let reply: [u8; 12] = [
0u8,
0, // ID
@@ -162,8 +138,6 @@ fn forge_replies(replies: &Vec<Forwarding>, dns_name_string: String, qname_parts
0, // Arcount
];
let offset = if tcp { 2 } else { 0 };
let now = Instant::now();
println!(
@@ -180,9 +154,9 @@ fn forge_replies(replies: &Vec<Forwarding>, dns_name_string: String, qname_parts
);
let mut new_reply = reply.clone().to_vec();
new_reply[0..2].copy_from_slice(&original_message[offset..][0..=1]);
new_reply[0..2].copy_from_slice(&original_message[0..=1]);
new_reply[6..8].copy_from_slice(&(replies.len() as u16).to_be_bytes());
new_reply.extend_from_slice(&original_message[offset..][12..=12 + qname_parts.iter().map(|p| p.len()+1).sum::<usize>() + 1 + 3]);
new_reply.extend_from_slice(&original_message[12..=12 + qname_parts.iter().map(|p| p.len()+1).sum::<usize>() + 1 + 3]);
for reply in replies {
new_reply.extend_from_slice(&parts_to_dns_name(&qname_parts));
new_reply.extend_from_slice(&[0, 1]);
@@ -191,21 +165,20 @@ fn forge_replies(replies: &Vec<Forwarding>, dns_name_string: String, qname_parts
new_reply.extend_from_slice(&[0, 4]);
new_reply.extend_from_slice(&reply.forged_ip.octets());
}
if tcp {
reply_buf.extend_from_slice(&(new_reply.len() as u16).to_be_bytes());
}
reply_buf.extend_from_slice(&new_reply);
}
async fn handle_dns_response(buf: &[u8], reply_buf: &mut Vec<u8>, tcp: bool) -> anyhow::Result<()> {
async fn handle_dns_response(mut handler: impl SelectiveRoutingHandler) -> anyhow::Result<()> {
let buf = handler.get_original_message();
let mut cursor = Cursor::from(buf);
let offset: usize = if tcp { 2 } else { 0 };
let mut cursor = Cursor::from(&buf[offset..]);
let mut reply_buf = Vec::new();
// Identify some metadata from the query
if cursor.seek(4).is_err() {
return Err(anyhow::Error::msg("Failed to seek to QDCount"));
}
let qdcount = u16::from_be_bytes(cursor.next_array::<2>().ok_or(anyhow::Error::msg("Failed to read QDCount"))?);
if qdcount != 1 {
eprintln!("Got qdcount: {}", qdcount);
@@ -223,8 +196,7 @@ async fn handle_dns_response(buf: &[u8], reply_buf: &mut Vec<u8>, tcp: bool) ->
// If the query is for anything other than qclass IN or qtype A, just forward the query upstream
if qclass != 1 || qtype != 1 {
let upstream_reply = query_upstream_resolvers(buf, tcp).await?;
reply_buf.extend_from_slice(&upstream_reply);
handler.reply(handler.query_upstream().await.or(Err(anyhow::Error::msg("Failed to query upstream resolver.")))?).await.or(Err(anyhow::Error::msg("Failed to query upstream resolver.")))?;
return Ok(())
}
@@ -240,13 +212,14 @@ async fn handle_dns_response(buf: &[u8], reply_buf: &mut Vec<u8>, tcp: bool) ->
// Let's first lookup the qname in the Forwardings to see if we have non-expired answers
if entries.len() > 0 && entries.iter().all(|e| e.expires > now) {
forge_replies(&entries, dns_name_string, qname_parts, buf, reply_buf, tcp);
forge_replies(&entries, dns_name_string, qname_parts, buf, &mut reply_buf);
handler.reply(reply_buf).await.or(Err(anyhow::Error::msg("Failed to reply!")))?;
return Ok(());
} else {
let upstream_reply = query_upstream_resolvers(buf, tcp).await?;
let upstream_reply = handler.query_upstream().await.or(Err(anyhow::Error::msg("Failed to query upstream resolver.")))?;
// Try an answer from the upstream response that has type A.
let a_answers = match AnswerIterator::from(&upstream_reply[offset..]) {
let a_answers = match AnswerIterator::from(&upstream_reply) {
Some(answers) => answers
.filter(|Answer { rrtype, .. }| *rrtype == 1)
.collect::<Vec<Answer>>(),
@@ -385,11 +358,13 @@ async fn handle_dns_response(buf: &[u8], reply_buf: &mut Vec<u8>, tcp: bool) ->
replies.extend(entries);
}
forge_replies(&replies, dns_name_string, qname_parts, buf, reply_buf, tcp);
forge_replies(&replies, dns_name_string, qname_parts, buf, &mut reply_buf);
} else {
reply_buf.extend_from_slice(&upstream_reply);
}
handler.reply(reply_buf).await.or(Err(anyhow::Error::msg("Failed to reply.")))?;
Ok(())
}
}
@@ -471,86 +446,136 @@ static COMMAND_PREFIX: OnceLock<String> = OnceLock::new();
static UPSTREAM_DNS: OnceLock<String> = OnceLock::new();
static LISTEN_ADDR: OnceLock<String> = OnceLock::new();
async fn tcp_handler(listener: TcpListener) {
loop {
let accepted = listener.accept().await;
if let Err(e) = &accepted {
eprintln!("[TCP] Failed to accept TCP socket with error: {:?}", e);
struct Config {
listen_addr: SocketAddrV4,
upstream_resolver: SocketAddrV4
}
trait SelectiveRoutingHandler {
async fn query_upstream(&self) -> Result<Vec<u8>, ()>;
async fn reply(&mut self, message: Vec<u8>) -> Result<(), ()>;
fn get_original_message(&self) -> &[u8];
}
struct UDPSelectiveRoutingHandler {
config: Arc<Config>,
original_message: Vec<u8>,
reply_address: Ipv4Addr,
reply_port: u16,
reply_socket: Arc<UdpSocket>,
}
impl UDPSelectiveRoutingHandler {
pub async fn from(config: &Arc<Config>, receiving_socket: &Arc<UdpSocket>) -> Result<Self, ()> {
let mut recv_buf = vec![0; 512];
let reply_address;
let reply_port;
match receiving_socket.recv_from(&mut recv_buf).await {
Ok((_, SocketAddr::V4(addr))) => {
reply_address = *addr.ip();
reply_port = addr.port();
},
Ok(_) => unimplemented!(),
Err(_) => {
return Err(())
}
}
let (mut socket, addr) = accepted.unwrap();
let mut size_buf = [0u8; 2];
let Ok(2) = socket.read(&mut size_buf).await else {
eprintln!("[TCP] Failed to read message size from {}", addr);
continue;
};
let message_size = u16::from_be_bytes(size_buf) as usize;
let mut message_buffer = Vec::new();
message_buffer.resize(message_size + 2, 0);
match socket.read(&mut message_buffer[2..]).await {
Ok(n) => {
if n != message_size {
eprintln!(
"[TCP] Received too few bytes than expected from {}. Message size indicated {}, read {}",
addr, message_size, n
);
continue;
}
Ok(
Self {
config: Arc::clone(config),
original_message: recv_buf,
reply_address,
reply_port,
reply_socket: Arc::clone(receiving_socket)
}
Err(e) => {
eprintln!(
"[TCP] Failed to read message from {} with error: {:?}",
addr, e
);
continue;
}
}
message_buffer[0] = size_buf[0];
message_buffer[1] = size_buf[1];
let mut reply = Vec::new();
if let Err(e) = handle_dns_response(&message_buffer, &mut reply, true).await {
eprintln!(
"[TCP] Received error when handling response for {}: {:?}",
addr, e
);
continue;
}
if let Err(e) = socket.write(&reply).await {
eprintln!(
"[TCP] Received error when sending response to {}: {:?}",
addr, e
);
continue;
}
)
}
}
async fn udp_handler(socket: UdpSocket) {
loop {
let mut message_buffer = [0u8; 512];
match socket.recv_from(&mut message_buffer).await {
Ok((_, addr)) => {
let mut reply = Vec::new();
if let Err(e) = handle_dns_response(&message_buffer, &mut reply, false).await {
eprintln!(
"[UDP] Received error when handling response for {}: {:?}",
addr, e
);
continue;
}
if let Err(e) = socket.send_to(&reply, addr).await {
eprintln!(
"[UDP] Received error when sending response to {}: {:?}",
addr, e
);
continue;
}
impl SelectiveRoutingHandler for UDPSelectiveRoutingHandler {
async fn query_upstream(&self) -> Result<Vec<u8>, ()> {
let query_upstream_socket = UdpSocket::bind("0.0.0.0:0").await.or(Err(()))?;
query_upstream_socket.send_to(&self.original_message, self.config.upstream_resolver).await.or(Err(()))?;
let mut recv_buf = vec![0; 512];
query_upstream_socket.recv(&mut recv_buf).await.or(Err(()))?;
Ok(recv_buf)
}
async fn reply(&mut self, message: Vec<u8>) -> Result<(), ()> {
let _ = self.reply_socket.send_to(&message, (self.reply_address, self.reply_port)).await;
Ok(())
}
fn get_original_message(&self) -> &[u8] {
&self.original_message
}
}
struct TCPSelectiveRoutingHandler {
config: Arc<Config>,
original_message: Vec<u8>,
reply_stream: TcpStream
}
impl TCPSelectiveRoutingHandler {
async fn from(config: &Arc<Config>, mut receiving_stream: TcpStream) -> Result<Self, ()> {
let mut message_size_buf: [u8; 2] = [0, 0];
receiving_stream.read(&mut message_size_buf).await.or(Err(()))?;
let message_size: u16 = u16::from_be_bytes(message_size_buf);
let mut message = vec![0u8; message_size as usize];
receiving_stream.read(&mut message).await.or(Err(()))?;
Ok(
Self {
config: Arc::clone(config),
original_message: message,
reply_stream: receiving_stream
}
Err(e) => {
eprintln!("[UDP] Failed to read message with error: {:?}", e);
continue;
}
}
)
}
}
impl SelectiveRoutingHandler for TCPSelectiveRoutingHandler {
async fn query_upstream(&self) -> Result<Vec<u8>, ()> {
let upstream_socket = TcpSocket::new_v4().or(Err(()))?;
let mut upstream_stream = upstream_socket.connect(self.config.upstream_resolver.into()).await.or(Err(()))?;
upstream_stream.write(&(self.original_message.len() as u16).to_be_bytes()).await.or(Err(()))?;
upstream_stream.write(&self.original_message).await.or(Err(()))?;
let mut message_size_buf: [u8; 2] = [0, 0];
upstream_stream.read(&mut message_size_buf).await.or(Err(()))?;
let message_size: u16 = u16::from_be_bytes(message_size_buf);
let mut message = vec![0; message_size as usize];
upstream_stream.read(&mut message).await.or(Err(()))?;
Ok(message)
}
async fn reply(&mut self, message: Vec<u8>) -> Result<(), ()> {
self.reply_stream.write(&(message.len() as u16).to_be_bytes()).await.or(Err(()))?;
self.reply_stream.write(&message).await.or(Err(()))?;
Ok(())
}
fn get_original_message(&self) -> &[u8] {
&self.original_message
}
}
@@ -607,11 +632,54 @@ async fn main() -> anyhow::Result<()> {
}
let recv_socket_tcp = TcpListener::bind(LISTEN_ADDR.get().unwrap()).await?;
let recv_socket_udp = UdpSocket::bind(LISTEN_ADDR.get().unwrap()).await?;
let recv_socket_udp = Arc::new(UdpSocket::bind(LISTEN_ADDR.get().unwrap()).await?);
let config = Arc::new(Config {
listen_addr: LISTEN_ADDR.get().unwrap().parse::<SocketAddrV4>().unwrap(),
upstream_resolver: UPSTREAM_DNS.get().unwrap().parse::<SocketAddrV4>().unwrap()
});
tokio::try_join!(
tokio::spawn(tcp_handler(recv_socket_tcp)),
tokio::spawn(udp_handler(recv_socket_udp)),
tokio::spawn({
let config = Arc::clone(&config);
let recv_socket_udp = Arc::clone(&recv_socket_udp);
async move {
loop {
match UDPSelectiveRoutingHandler::from(&config, &recv_socket_udp).await {
Ok(handler) => {
tokio::spawn(async {
if let Err(e) = handle_dns_response(handler).await {
eprintln!("{}", e)
}
});
},
Err(_) => ()
}
}
}
}),
tokio::spawn({
let config = Arc::clone(&config);
async move {
loop {
match recv_socket_tcp.accept().await {
Ok((stream, _)) => {
eprintln!("[TCP] Accepted connection from {}", stream.peer_addr().unwrap());
if let Ok(handler) = TCPSelectiveRoutingHandler::from(&config, stream).await {
tokio::spawn(async {
if let Err(e) = handle_dns_response(handler).await {
eprintln!("{}", e)
}
});
}
},
Err(_) => ()
}
}
}
}),
)?;
Ok(())
}