From ffd3a5b268fb70b516aabc71d34040565e87cc88 Mon Sep 17 00:00:00 2001 From: Xnoe Date: Sat, 2 May 2026 15:23:37 +0100 Subject: [PATCH] Rework the codebase to genericise TCP/UDP handling --- src/main.rs | 308 ++++++++++++++++++++++++++++++++-------------------- 1 file changed, 188 insertions(+), 120 deletions(-) diff --git a/src/main.rs b/src/main.rs index 49ec772..327eba4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,12 +6,13 @@ mod dns_parser; mod ip_pool; use std::collections::HashMap; -use std::net::Ipv4Addr; +use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; use std::process::{Command, Output}; -use std::sync::OnceLock; +use std::sync::{Arc, OnceLock}; use std::time::{Duration, Instant}; use ini::Ini; +use tokio::net::TcpSocket; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, net::{TcpListener, TcpStream, UdpSocket}, @@ -121,32 +122,7 @@ async fn teardown_forwarding( let _dnat_output = run_command_string(dnat_command).unwrap(); } -async fn query_upstream_resolvers(original_message: &[u8], tcp: bool) -> anyhow::Result> { - let mut upstream_reply = Vec::new(); - - if tcp { - let mut upstream = TcpStream::connect(UPSTREAM_DNS.get().unwrap()).await?; - //upstream.write(&(original_message.len() as u16).to_be_bytes()).await?; - upstream.write(&original_message).await?; - let mut size_buffer = [0u8; 2]; - upstream.read(&mut size_buffer).await?; - let size: u16 = u16::from_be_bytes(size_buffer); - upstream_reply.resize(size as usize + 2, 0u8); - upstream.read(&mut upstream_reply[2..]).await?; - upstream_reply[0] = size_buffer[0]; - upstream_reply[1] = size_buffer[1]; - } else { - let upstream = UdpSocket::bind("0.0.0.0:0").await?; - upstream.connect(UPSTREAM_DNS.get().unwrap()).await?; - upstream.send(original_message).await?; - upstream_reply.resize(512, 0); - upstream.recv_from(&mut upstream_reply[..512]).await?; - } - - return Ok(upstream_reply); -} - -fn forge_replies(replies: &Vec, dns_name_string: String, qname_parts: Vec>, original_message: &[u8], reply_buf: &mut Vec, tcp: bool) { +fn forge_replies(replies: &Vec, dns_name_string: String, qname_parts: Vec>, original_message: &[u8], reply_buf: &mut Vec) { let reply: [u8; 12] = [ 0u8, 0, // ID @@ -162,8 +138,6 @@ fn forge_replies(replies: &Vec, dns_name_string: String, qname_parts 0, // Arcount ]; - let offset = if tcp { 2 } else { 0 }; - let now = Instant::now(); println!( @@ -180,9 +154,9 @@ fn forge_replies(replies: &Vec, dns_name_string: String, qname_parts ); let mut new_reply = reply.clone().to_vec(); - new_reply[0..2].copy_from_slice(&original_message[offset..][0..=1]); + new_reply[0..2].copy_from_slice(&original_message[0..=1]); new_reply[6..8].copy_from_slice(&(replies.len() as u16).to_be_bytes()); - new_reply.extend_from_slice(&original_message[offset..][12..=12 + qname_parts.iter().map(|p| p.len()+1).sum::() + 1 + 3]); + new_reply.extend_from_slice(&original_message[12..=12 + qname_parts.iter().map(|p| p.len()+1).sum::() + 1 + 3]); for reply in replies { new_reply.extend_from_slice(&parts_to_dns_name(&qname_parts)); new_reply.extend_from_slice(&[0, 1]); @@ -191,21 +165,20 @@ fn forge_replies(replies: &Vec, dns_name_string: String, qname_parts new_reply.extend_from_slice(&[0, 4]); new_reply.extend_from_slice(&reply.forged_ip.octets()); } - if tcp { - reply_buf.extend_from_slice(&(new_reply.len() as u16).to_be_bytes()); - } reply_buf.extend_from_slice(&new_reply); } -async fn handle_dns_response(buf: &[u8], reply_buf: &mut Vec, tcp: bool) -> anyhow::Result<()> { +async fn handle_dns_response(mut handler: impl SelectiveRoutingHandler) -> anyhow::Result<()> { + let buf = handler.get_original_message(); + let mut cursor = Cursor::from(buf); - let offset: usize = if tcp { 2 } else { 0 }; - let mut cursor = Cursor::from(&buf[offset..]); + let mut reply_buf = Vec::new(); // Identify some metadata from the query if cursor.seek(4).is_err() { return Err(anyhow::Error::msg("Failed to seek to QDCount")); } + let qdcount = u16::from_be_bytes(cursor.next_array::<2>().ok_or(anyhow::Error::msg("Failed to read QDCount"))?); if qdcount != 1 { eprintln!("Got qdcount: {}", qdcount); @@ -223,8 +196,7 @@ async fn handle_dns_response(buf: &[u8], reply_buf: &mut Vec, tcp: bool) -> // If the query is for anything other than qclass IN or qtype A, just forward the query upstream if qclass != 1 || qtype != 1 { - let upstream_reply = query_upstream_resolvers(buf, tcp).await?; - reply_buf.extend_from_slice(&upstream_reply); + handler.reply(handler.query_upstream().await.or(Err(anyhow::Error::msg("Failed to query upstream resolver.")))?).await.or(Err(anyhow::Error::msg("Failed to query upstream resolver.")))?; return Ok(()) } @@ -240,13 +212,14 @@ async fn handle_dns_response(buf: &[u8], reply_buf: &mut Vec, tcp: bool) -> // Let's first lookup the qname in the Forwardings to see if we have non-expired answers if entries.len() > 0 && entries.iter().all(|e| e.expires > now) { - forge_replies(&entries, dns_name_string, qname_parts, buf, reply_buf, tcp); + forge_replies(&entries, dns_name_string, qname_parts, buf, &mut reply_buf); + handler.reply(reply_buf).await.or(Err(anyhow::Error::msg("Failed to reply!")))?; return Ok(()); } else { - let upstream_reply = query_upstream_resolvers(buf, tcp).await?; + let upstream_reply = handler.query_upstream().await.or(Err(anyhow::Error::msg("Failed to query upstream resolver.")))?; // Try an answer from the upstream response that has type A. - let a_answers = match AnswerIterator::from(&upstream_reply[offset..]) { + let a_answers = match AnswerIterator::from(&upstream_reply) { Some(answers) => answers .filter(|Answer { rrtype, .. }| *rrtype == 1) .collect::>(), @@ -385,11 +358,13 @@ async fn handle_dns_response(buf: &[u8], reply_buf: &mut Vec, tcp: bool) -> replies.extend(entries); } - forge_replies(&replies, dns_name_string, qname_parts, buf, reply_buf, tcp); + forge_replies(&replies, dns_name_string, qname_parts, buf, &mut reply_buf); } else { reply_buf.extend_from_slice(&upstream_reply); } + handler.reply(reply_buf).await.or(Err(anyhow::Error::msg("Failed to reply.")))?; + Ok(()) } } @@ -471,86 +446,136 @@ static COMMAND_PREFIX: OnceLock = OnceLock::new(); static UPSTREAM_DNS: OnceLock = OnceLock::new(); static LISTEN_ADDR: OnceLock = OnceLock::new(); -async fn tcp_handler(listener: TcpListener) { - loop { - let accepted = listener.accept().await; - if let Err(e) = &accepted { - eprintln!("[TCP] Failed to accept TCP socket with error: {:?}", e); +struct Config { + listen_addr: SocketAddrV4, + upstream_resolver: SocketAddrV4 +} + +trait SelectiveRoutingHandler { + async fn query_upstream(&self) -> Result, ()>; + async fn reply(&mut self, message: Vec) -> Result<(), ()>; + + fn get_original_message(&self) -> &[u8]; +} + +struct UDPSelectiveRoutingHandler { + config: Arc, + original_message: Vec, + reply_address: Ipv4Addr, + reply_port: u16, + reply_socket: Arc, +} + +impl UDPSelectiveRoutingHandler { + pub async fn from(config: &Arc, receiving_socket: &Arc) -> Result { + let mut recv_buf = vec![0; 512]; + + let reply_address; + let reply_port; + + match receiving_socket.recv_from(&mut recv_buf).await { + Ok((_, SocketAddr::V4(addr))) => { + reply_address = *addr.ip(); + reply_port = addr.port(); + }, + Ok(_) => unimplemented!(), + Err(_) => { + return Err(()) + } } - let (mut socket, addr) = accepted.unwrap(); - let mut size_buf = [0u8; 2]; - let Ok(2) = socket.read(&mut size_buf).await else { - eprintln!("[TCP] Failed to read message size from {}", addr); - continue; - }; - let message_size = u16::from_be_bytes(size_buf) as usize; - let mut message_buffer = Vec::new(); - message_buffer.resize(message_size + 2, 0); - match socket.read(&mut message_buffer[2..]).await { - Ok(n) => { - if n != message_size { - eprintln!( - "[TCP] Received too few bytes than expected from {}. Message size indicated {}, read {}", - addr, message_size, n - ); - continue; - } + Ok( + Self { + config: Arc::clone(config), + original_message: recv_buf, + reply_address, + reply_port, + reply_socket: Arc::clone(receiving_socket) } - Err(e) => { - eprintln!( - "[TCP] Failed to read message from {} with error: {:?}", - addr, e - ); - continue; - } - } - message_buffer[0] = size_buf[0]; - message_buffer[1] = size_buf[1]; - let mut reply = Vec::new(); - if let Err(e) = handle_dns_response(&message_buffer, &mut reply, true).await { - eprintln!( - "[TCP] Received error when handling response for {}: {:?}", - addr, e - ); - continue; - } - if let Err(e) = socket.write(&reply).await { - eprintln!( - "[TCP] Received error when sending response to {}: {:?}", - addr, e - ); - continue; - } + ) } } -async fn udp_handler(socket: UdpSocket) { - loop { - let mut message_buffer = [0u8; 512]; - match socket.recv_from(&mut message_buffer).await { - Ok((_, addr)) => { - let mut reply = Vec::new(); - if let Err(e) = handle_dns_response(&message_buffer, &mut reply, false).await { - eprintln!( - "[UDP] Received error when handling response for {}: {:?}", - addr, e - ); - continue; - } - if let Err(e) = socket.send_to(&reply, addr).await { - eprintln!( - "[UDP] Received error when sending response to {}: {:?}", - addr, e - ); - continue; - } +impl SelectiveRoutingHandler for UDPSelectiveRoutingHandler { + async fn query_upstream(&self) -> Result, ()> { + let query_upstream_socket = UdpSocket::bind("0.0.0.0:0").await.or(Err(()))?; + + query_upstream_socket.send_to(&self.original_message, self.config.upstream_resolver).await.or(Err(()))?; + + let mut recv_buf = vec![0; 512]; + + query_upstream_socket.recv(&mut recv_buf).await.or(Err(()))?; + + Ok(recv_buf) + } + + async fn reply(&mut self, message: Vec) -> Result<(), ()> { + let _ = self.reply_socket.send_to(&message, (self.reply_address, self.reply_port)).await; + + Ok(()) + } + + fn get_original_message(&self) -> &[u8] { + &self.original_message + } +} + +struct TCPSelectiveRoutingHandler { + config: Arc, + original_message: Vec, + reply_stream: TcpStream +} + +impl TCPSelectiveRoutingHandler { + async fn from(config: &Arc, mut receiving_stream: TcpStream) -> Result { + let mut message_size_buf: [u8; 2] = [0, 0]; + receiving_stream.read(&mut message_size_buf).await.or(Err(()))?; + + let message_size: u16 = u16::from_be_bytes(message_size_buf); + + let mut message = vec![0u8; message_size as usize]; + + receiving_stream.read(&mut message).await.or(Err(()))?; + + Ok( + Self { + config: Arc::clone(config), + original_message: message, + reply_stream: receiving_stream } - Err(e) => { - eprintln!("[UDP] Failed to read message with error: {:?}", e); - continue; - } - } + ) + } +} + +impl SelectiveRoutingHandler for TCPSelectiveRoutingHandler { + async fn query_upstream(&self) -> Result, ()> { + let upstream_socket = TcpSocket::new_v4().or(Err(()))?; + let mut upstream_stream = upstream_socket.connect(self.config.upstream_resolver.into()).await.or(Err(()))?; + + upstream_stream.write(&(self.original_message.len() as u16).to_be_bytes()).await.or(Err(()))?; + upstream_stream.write(&self.original_message).await.or(Err(()))?; + + let mut message_size_buf: [u8; 2] = [0, 0]; + upstream_stream.read(&mut message_size_buf).await.or(Err(()))?; + + let message_size: u16 = u16::from_be_bytes(message_size_buf); + + let mut message = vec![0; message_size as usize]; + + upstream_stream.read(&mut message).await.or(Err(()))?; + + Ok(message) + } + + async fn reply(&mut self, message: Vec) -> Result<(), ()> { + self.reply_stream.write(&(message.len() as u16).to_be_bytes()).await.or(Err(()))?; + self.reply_stream.write(&message).await.or(Err(()))?; + + Ok(()) + } + + fn get_original_message(&self) -> &[u8] { + &self.original_message } } @@ -607,11 +632,54 @@ async fn main() -> anyhow::Result<()> { } let recv_socket_tcp = TcpListener::bind(LISTEN_ADDR.get().unwrap()).await?; - let recv_socket_udp = UdpSocket::bind(LISTEN_ADDR.get().unwrap()).await?; + let recv_socket_udp = Arc::new(UdpSocket::bind(LISTEN_ADDR.get().unwrap()).await?); + + let config = Arc::new(Config { + listen_addr: LISTEN_ADDR.get().unwrap().parse::().unwrap(), + upstream_resolver: UPSTREAM_DNS.get().unwrap().parse::().unwrap() + }); tokio::try_join!( - tokio::spawn(tcp_handler(recv_socket_tcp)), - tokio::spawn(udp_handler(recv_socket_udp)), + tokio::spawn({ + let config = Arc::clone(&config); + let recv_socket_udp = Arc::clone(&recv_socket_udp); + async move { + loop { + match UDPSelectiveRoutingHandler::from(&config, &recv_socket_udp).await { + Ok(handler) => { + tokio::spawn(async { + if let Err(e) = handle_dns_response(handler).await { + eprintln!("{}", e) + } + }); + }, + + Err(_) => () + } + } + } + }), + tokio::spawn({ + let config = Arc::clone(&config); + async move { + loop { + match recv_socket_tcp.accept().await { + Ok((stream, _)) => { + eprintln!("[TCP] Accepted connection from {}", stream.peer_addr().unwrap()); + if let Ok(handler) = TCPSelectiveRoutingHandler::from(&config, stream).await { + tokio::spawn(async { + if let Err(e) = handle_dns_response(handler).await { + eprintln!("{}", e) + } + }); + } + }, + + Err(_) => () + } + } + } + }), )?; Ok(()) }