From d0b5f8c7be980afb861b72fa4cf81889b00e3bd0 Mon Sep 17 00:00:00 2001 From: Xnoe Date: Thu, 14 May 2026 23:55:59 +0100 Subject: [PATCH] cargo fmt --- src/cursor.rs | 7 +- src/main.rs | 210 +++++++++++++++++++++++++++++++------------------- 2 files changed, 136 insertions(+), 81 deletions(-) diff --git a/src/cursor.rs b/src/cursor.rs index 52cc2cf..0f8e366 100644 --- a/src/cursor.rs +++ b/src/cursor.rs @@ -44,7 +44,10 @@ impl<'a, T> Cursor<'a, T> { } } - pub fn next_array(&mut self) -> Option<[T; N]> where [T; N]: TryFrom<&'a [T]> { + pub fn next_array(&mut self) -> Option<[T; N]> + where + [T; N]: TryFrom<&'a [T]>, + { Some(self.next_slice(N)?.try_into().ok()?) } @@ -57,4 +60,4 @@ impl<'a, T> Cursor<'a, T> { Ok(()) } } -} \ No newline at end of file +} diff --git a/src/main.rs b/src/main.rs index f6933d9..92072d2 100644 --- a/src/main.rs +++ b/src/main.rs @@ -21,8 +21,8 @@ use tokio::{ use crate::cursor::Cursor; use crate::dns_parser::{ - Answer, AnswerIterator, dns_name_to_parts, dns_parts_to_string, - parts_to_dns_name, string_to_dns_name, + Answer, AnswerIterator, dns_name_to_parts, dns_parts_to_string, parts_to_dns_name, + string_to_dns_name, }; use crate::ip_pool::IpPool; @@ -122,7 +122,12 @@ async fn teardown_forwarding( let _dnat_output = run_command_string(dnat_command).unwrap(); } -fn forge_replies(replies: &Vec, dns_name_string: String, qname_parts: Vec>, original_message: &[u8]) -> Vec { +fn forge_replies( + replies: &Vec, + dns_name_string: String, + qname_parts: Vec>, + original_message: &[u8], +) -> Vec { let reply: [u8; 12] = [ 0u8, 0, // ID @@ -156,7 +161,9 @@ fn forge_replies(replies: &Vec, dns_name_string: String, qname_parts let mut new_reply = reply.clone().to_vec(); new_reply[0..2].copy_from_slice(&original_message[0..=1]); new_reply[6..8].copy_from_slice(&(replies.len() as u16).to_be_bytes()); - new_reply.extend_from_slice(&original_message[12..=12 + qname_parts.iter().map(|p| p.len()+1).sum::() + 1 + 3]); + new_reply.extend_from_slice( + &original_message[12..=12 + qname_parts.iter().map(|p| p.len() + 1).sum::() + 1 + 3], + ); for reply in replies { new_reply.extend_from_slice(&parts_to_dns_name(&qname_parts)); new_reply.extend_from_slice(&[0, 1]); @@ -165,7 +172,7 @@ fn forge_replies(replies: &Vec, dns_name_string: String, qname_parts new_reply.extend_from_slice(&[0, 4]); new_reply.extend_from_slice(&reply.forged_ip.octets()); } - + new_reply } @@ -178,7 +185,11 @@ async fn handle_dns_response(mut handler: impl SelectiveRoutingHandler) -> anyho return Err(anyhow::Error::msg("Failed to seek to QDCount")); } - let qdcount = u16::from_be_bytes(cursor.next_array::<2>().ok_or(anyhow::Error::msg("Failed to read QDCount"))?); + let qdcount = u16::from_be_bytes( + cursor + .next_array::<2>() + .ok_or(anyhow::Error::msg("Failed to read QDCount"))?, + ); if qdcount != 1 { eprintln!("Got qdcount: {}", qdcount); return Err(anyhow::Error::msg(format!( @@ -191,13 +202,28 @@ async fn handle_dns_response(mut handler: impl SelectiveRoutingHandler) -> anyho } let qname_parts = dns_name_to_parts(&mut cursor).ok_or(anyhow::Error::msg("Failed to decode QName."))?; - let qtype = u16::from_be_bytes(cursor.next_array::<2>().ok_or(anyhow::Error::msg("Failed to read QType."))?); - let qclass = u16::from_be_bytes(cursor.next_array::<2>().ok_or(anyhow::Error::msg("Failed to read QClass."))?); + let qtype = u16::from_be_bytes( + cursor + .next_array::<2>() + .ok_or(anyhow::Error::msg("Failed to read QType."))?, + ); + let qclass = u16::from_be_bytes( + cursor + .next_array::<2>() + .ok_or(anyhow::Error::msg("Failed to read QClass."))?, + ); // If the query is for anything other than qclass IN or qtype A, just forward the query upstream if qclass != 1 || qtype != 1 { - handler.reply(handler.query_upstream().await.or(Err(anyhow::Error::msg("Failed to query upstream resolver.")))?).await.or(Err(anyhow::Error::msg("Failed to query upstream resolver.")))?; - return Ok(()) + handler + .reply(handler.query_upstream().await.or(Err(anyhow::Error::msg( + "Failed to query upstream resolver.", + )))?) + .await + .or(Err(anyhow::Error::msg( + "Failed to query upstream resolver.", + )))?; + return Ok(()); } let now = Instant::now(); @@ -212,13 +238,16 @@ async fn handle_dns_response(mut handler: impl SelectiveRoutingHandler) -> anyho // Let's first lookup the qname in the Forwardings to see if we have non-expired answers if entries.len() > 0 && entries.iter().all(|e| e.expires > now) { - handler.reply( - forge_replies(&entries, dns_name_string, qname_parts, buf) - ).await.or(Err(anyhow::Error::msg("Failed to reply!")))?; + handler + .reply(forge_replies(&entries, dns_name_string, qname_parts, buf)) + .await + .or(Err(anyhow::Error::msg("Failed to reply!")))?; return Ok(()); } - let upstream_reply = handler.query_upstream().await.or(Err(anyhow::Error::msg("Failed to query upstream resolver.")))?; + let upstream_reply = handler.query_upstream().await.or(Err(anyhow::Error::msg( + "Failed to query upstream resolver.", + )))?; // Try an answer from the upstream response that has type A. let answer_iterator = AnswerIterator::from(&upstream_reply).or(Err(anyhow::Error::msg( @@ -257,8 +286,11 @@ async fn handle_dns_response(mut handler: impl SelectiveRoutingHandler) -> anyho } if !should_forge { - handler.reply(upstream_reply).await.or(Err(anyhow::Error::msg("Failed to reply.")))?; - return Ok(()) + handler + .reply(upstream_reply) + .await + .or(Err(anyhow::Error::msg("Failed to reply.")))?; + return Ok(()); } // Normalise a_answers so we're working with Ipv4Addr @@ -268,9 +300,7 @@ async fn handle_dns_response(mut handler: impl SelectiveRoutingHandler) -> anyho ( answer.name, answer.ttl, - Ipv4Addr::from( - as TryInto<[u8; 4]>>::try_into(answer.rdata).unwrap(), - ), + Ipv4Addr::from( as TryInto<[u8; 4]>>::try_into(answer.rdata).unwrap()), ) }) .collect::<_>(); @@ -360,7 +390,10 @@ async fn handle_dns_response(mut handler: impl SelectiveRoutingHandler) -> anyho replies.extend(entries); } - handler.reply(forge_replies(&replies, dns_name_string, qname_parts, buf)).await.or(Err(anyhow::Error::msg("Failed to reply.")))?; + handler + .reply(forge_replies(&replies, dns_name_string, qname_parts, buf)) + .await + .or(Err(anyhow::Error::msg("Failed to reply.")))?; Ok(()) } @@ -375,38 +408,31 @@ struct Forwarding { enum FwmarkDomainMatchType { Exact, - Suffix + Suffix, } struct FwmarkConfig { - entries: Vec<(FwmarkDomainMatchType, Vec>, u32)> + entries: Vec<(FwmarkDomainMatchType, Vec>, u32)>, } impl FwmarkConfig { fn new() -> Self { Self { - entries: Vec::new() + entries: Vec::new(), } } fn insert(&mut self, dns_name: Vec, fwmark: u32) { - self.entries.push( - ( - FwmarkDomainMatchType::Exact, - dns_name_to_parts(&mut Cursor::from(dns_name.as_slice())).unwrap(), - fwmark - ) - ); + self.entries.push(( + FwmarkDomainMatchType::Exact, + dns_name_to_parts(&mut Cursor::from(dns_name.as_slice())).unwrap(), + fwmark, + )); } fn insert_wildcard(&mut self, suffix_parts: Vec>, fwmark: u32) { - self.entries.push( - ( - FwmarkDomainMatchType::Suffix, - suffix_parts, - fwmark - ) - ); + self.entries + .push((FwmarkDomainMatchType::Suffix, suffix_parts, fwmark)); } fn get(&self, dns_name: &Vec) -> Option { @@ -416,12 +442,12 @@ impl FwmarkConfig { match match_type { FwmarkDomainMatchType::Exact => { if parts == *_parts { - return Some(*fwmark) + return Some(*fwmark); } - }, + } FwmarkDomainMatchType::Suffix => { if parts.ends_with(_parts.as_slice()) { - return Some(*fwmark) + return Some(*fwmark); } } } @@ -444,7 +470,7 @@ static LISTEN_ADDR: OnceLock = OnceLock::new(); struct Config { listen_addr: SocketAddrV4, - upstream_resolver: SocketAddrV4 + upstream_resolver: SocketAddrV4, } trait SelectiveRoutingHandler { @@ -473,22 +499,18 @@ impl UDPSelectiveRoutingHandler { Ok((_, SocketAddr::V4(addr))) => { reply_address = *addr.ip(); reply_port = addr.port(); - }, - Ok(_) => unimplemented!(), - Err(_) => { - return Err(()) } + Ok(_) => unimplemented!(), + Err(_) => return Err(()), } - Ok( - Self { - config: Arc::clone(config), - original_message: recv_buf, - reply_address, - reply_port, - reply_socket: Arc::clone(receiving_socket) - } - ) + Ok(Self { + config: Arc::clone(config), + original_message: recv_buf, + reply_address, + reply_port, + reply_socket: Arc::clone(receiving_socket), + }) } } @@ -496,18 +518,27 @@ impl SelectiveRoutingHandler for UDPSelectiveRoutingHandler { async fn query_upstream(&self) -> Result, ()> { let query_upstream_socket = UdpSocket::bind("0.0.0.0:0").await.or(Err(()))?; - query_upstream_socket.send_to(&self.original_message, self.config.upstream_resolver).await.or(Err(()))?; + query_upstream_socket + .send_to(&self.original_message, self.config.upstream_resolver) + .await + .or(Err(()))?; let mut recv_buf = vec![0; 512]; - query_upstream_socket.recv(&mut recv_buf).await.or(Err(()))?; + query_upstream_socket + .recv(&mut recv_buf) + .await + .or(Err(()))?; Ok(recv_buf) } async fn reply(&mut self, message: Vec) -> Result<(), ()> { - let _ = self.reply_socket.send_to(&message, (self.reply_address, self.reply_port)).await; - + let _ = self + .reply_socket + .send_to(&message, (self.reply_address, self.reply_port)) + .await; + Ok(()) } @@ -519,13 +550,16 @@ impl SelectiveRoutingHandler for UDPSelectiveRoutingHandler { struct TCPSelectiveRoutingHandler { config: Arc, original_message: Vec, - reply_stream: TcpStream + reply_stream: TcpStream, } impl TCPSelectiveRoutingHandler { async fn from(config: &Arc, mut receiving_stream: TcpStream) -> Result { let mut message_size_buf: [u8; 2] = [0, 0]; - receiving_stream.read(&mut message_size_buf).await.or(Err(()))?; + receiving_stream + .read(&mut message_size_buf) + .await + .or(Err(()))?; let message_size: u16 = u16::from_be_bytes(message_size_buf); @@ -533,26 +567,36 @@ impl TCPSelectiveRoutingHandler { receiving_stream.read(&mut message).await.or(Err(()))?; - Ok( - Self { - config: Arc::clone(config), - original_message: message, - reply_stream: receiving_stream - } - ) + Ok(Self { + config: Arc::clone(config), + original_message: message, + reply_stream: receiving_stream, + }) } } impl SelectiveRoutingHandler for TCPSelectiveRoutingHandler { async fn query_upstream(&self) -> Result, ()> { let upstream_socket = TcpSocket::new_v4().or(Err(()))?; - let mut upstream_stream = upstream_socket.connect(self.config.upstream_resolver.into()).await.or(Err(()))?; + let mut upstream_stream = upstream_socket + .connect(self.config.upstream_resolver.into()) + .await + .or(Err(()))?; - upstream_stream.write(&(self.original_message.len() as u16).to_be_bytes()).await.or(Err(()))?; - upstream_stream.write(&self.original_message).await.or(Err(()))?; + upstream_stream + .write(&(self.original_message.len() as u16).to_be_bytes()) + .await + .or(Err(()))?; + upstream_stream + .write(&self.original_message) + .await + .or(Err(()))?; let mut message_size_buf: [u8; 2] = [0, 0]; - upstream_stream.read(&mut message_size_buf).await.or(Err(()))?; + upstream_stream + .read(&mut message_size_buf) + .await + .or(Err(()))?; let message_size: u16 = u16::from_be_bytes(message_size_buf); @@ -564,7 +608,10 @@ impl SelectiveRoutingHandler for TCPSelectiveRoutingHandler { } async fn reply(&mut self, message: Vec) -> Result<(), ()> { - self.reply_stream.write(&(message.len() as u16).to_be_bytes()).await.or(Err(()))?; + self.reply_stream + .write(&(message.len() as u16).to_be_bytes()) + .await + .or(Err(()))?; self.reply_stream.write(&message).await.or(Err(()))?; Ok(()) @@ -632,7 +679,7 @@ async fn main() -> anyhow::Result<()> { let config = Arc::new(Config { listen_addr: LISTEN_ADDR.get().unwrap().parse::().unwrap(), - upstream_resolver: UPSTREAM_DNS.get().unwrap().parse::().unwrap() + upstream_resolver: UPSTREAM_DNS.get().unwrap().parse::().unwrap(), }); tokio::try_join!( @@ -648,9 +695,9 @@ async fn main() -> anyhow::Result<()> { eprintln!("{}", e) } }); - }, + } - Err(_) => () + Err(_) => (), } } } @@ -661,17 +708,22 @@ async fn main() -> anyhow::Result<()> { loop { match recv_socket_tcp.accept().await { Ok((stream, _)) => { - eprintln!("[TCP] Accepted connection from {}", stream.peer_addr().unwrap()); - if let Ok(handler) = TCPSelectiveRoutingHandler::from(&config, stream).await { + eprintln!( + "[TCP] Accepted connection from {}", + stream.peer_addr().unwrap() + ); + if let Ok(handler) = + TCPSelectiveRoutingHandler::from(&config, stream).await + { tokio::spawn(async { if let Err(e) = handle_dns_response(handler).await { eprintln!("{}", e) } }); } - }, + } - Err(_) => () + Err(_) => (), } } }