cargo fmt

This commit is contained in:
2026-05-14 23:55:59 +01:00
parent 7d73832913
commit d0b5f8c7be
2 changed files with 136 additions and 81 deletions
+5 -2
View File
@@ -44,7 +44,10 @@ impl<'a, T> Cursor<'a, T> {
} }
} }
pub fn next_array<const N: usize>(&mut self) -> Option<[T; N]> where [T; N]: TryFrom<&'a [T]> { pub fn next_array<const N: usize>(&mut self) -> Option<[T; N]>
where
[T; N]: TryFrom<&'a [T]>,
{
Some(self.next_slice(N)?.try_into().ok()?) Some(self.next_slice(N)?.try_into().ok()?)
} }
@@ -57,4 +60,4 @@ impl<'a, T> Cursor<'a, T> {
Ok(()) Ok(())
} }
} }
} }
+131 -79
View File
@@ -21,8 +21,8 @@ use tokio::{
use crate::cursor::Cursor; use crate::cursor::Cursor;
use crate::dns_parser::{ use crate::dns_parser::{
Answer, AnswerIterator, dns_name_to_parts, dns_parts_to_string, Answer, AnswerIterator, dns_name_to_parts, dns_parts_to_string, parts_to_dns_name,
parts_to_dns_name, string_to_dns_name, string_to_dns_name,
}; };
use crate::ip_pool::IpPool; use crate::ip_pool::IpPool;
@@ -122,7 +122,12 @@ async fn teardown_forwarding(
let _dnat_output = run_command_string(dnat_command).unwrap(); let _dnat_output = run_command_string(dnat_command).unwrap();
} }
fn forge_replies(replies: &Vec<Forwarding>, dns_name_string: String, qname_parts: Vec<Vec<u8>>, original_message: &[u8]) -> Vec<u8> { fn forge_replies(
replies: &Vec<Forwarding>,
dns_name_string: String,
qname_parts: Vec<Vec<u8>>,
original_message: &[u8],
) -> Vec<u8> {
let reply: [u8; 12] = [ let reply: [u8; 12] = [
0u8, 0u8,
0, // ID 0, // ID
@@ -156,7 +161,9 @@ fn forge_replies(replies: &Vec<Forwarding>, dns_name_string: String, qname_parts
let mut new_reply = reply.clone().to_vec(); let mut new_reply = reply.clone().to_vec();
new_reply[0..2].copy_from_slice(&original_message[0..=1]); new_reply[0..2].copy_from_slice(&original_message[0..=1]);
new_reply[6..8].copy_from_slice(&(replies.len() as u16).to_be_bytes()); new_reply[6..8].copy_from_slice(&(replies.len() as u16).to_be_bytes());
new_reply.extend_from_slice(&original_message[12..=12 + qname_parts.iter().map(|p| p.len()+1).sum::<usize>() + 1 + 3]); new_reply.extend_from_slice(
&original_message[12..=12 + qname_parts.iter().map(|p| p.len() + 1).sum::<usize>() + 1 + 3],
);
for reply in replies { for reply in replies {
new_reply.extend_from_slice(&parts_to_dns_name(&qname_parts)); new_reply.extend_from_slice(&parts_to_dns_name(&qname_parts));
new_reply.extend_from_slice(&[0, 1]); new_reply.extend_from_slice(&[0, 1]);
@@ -165,7 +172,7 @@ fn forge_replies(replies: &Vec<Forwarding>, dns_name_string: String, qname_parts
new_reply.extend_from_slice(&[0, 4]); new_reply.extend_from_slice(&[0, 4]);
new_reply.extend_from_slice(&reply.forged_ip.octets()); new_reply.extend_from_slice(&reply.forged_ip.octets());
} }
new_reply new_reply
} }
@@ -178,7 +185,11 @@ async fn handle_dns_response(mut handler: impl SelectiveRoutingHandler) -> anyho
return Err(anyhow::Error::msg("Failed to seek to QDCount")); return Err(anyhow::Error::msg("Failed to seek to QDCount"));
} }
let qdcount = u16::from_be_bytes(cursor.next_array::<2>().ok_or(anyhow::Error::msg("Failed to read QDCount"))?); let qdcount = u16::from_be_bytes(
cursor
.next_array::<2>()
.ok_or(anyhow::Error::msg("Failed to read QDCount"))?,
);
if qdcount != 1 { if qdcount != 1 {
eprintln!("Got qdcount: {}", qdcount); eprintln!("Got qdcount: {}", qdcount);
return Err(anyhow::Error::msg(format!( return Err(anyhow::Error::msg(format!(
@@ -191,13 +202,28 @@ async fn handle_dns_response(mut handler: impl SelectiveRoutingHandler) -> anyho
} }
let qname_parts = let qname_parts =
dns_name_to_parts(&mut cursor).ok_or(anyhow::Error::msg("Failed to decode QName."))?; dns_name_to_parts(&mut cursor).ok_or(anyhow::Error::msg("Failed to decode QName."))?;
let qtype = u16::from_be_bytes(cursor.next_array::<2>().ok_or(anyhow::Error::msg("Failed to read QType."))?); let qtype = u16::from_be_bytes(
let qclass = u16::from_be_bytes(cursor.next_array::<2>().ok_or(anyhow::Error::msg("Failed to read QClass."))?); cursor
.next_array::<2>()
.ok_or(anyhow::Error::msg("Failed to read QType."))?,
);
let qclass = u16::from_be_bytes(
cursor
.next_array::<2>()
.ok_or(anyhow::Error::msg("Failed to read QClass."))?,
);
// If the query is for anything other than qclass IN or qtype A, just forward the query upstream // If the query is for anything other than qclass IN or qtype A, just forward the query upstream
if qclass != 1 || qtype != 1 { if qclass != 1 || qtype != 1 {
handler.reply(handler.query_upstream().await.or(Err(anyhow::Error::msg("Failed to query upstream resolver.")))?).await.or(Err(anyhow::Error::msg("Failed to query upstream resolver.")))?; handler
return Ok(()) .reply(handler.query_upstream().await.or(Err(anyhow::Error::msg(
"Failed to query upstream resolver.",
)))?)
.await
.or(Err(anyhow::Error::msg(
"Failed to query upstream resolver.",
)))?;
return Ok(());
} }
let now = Instant::now(); let now = Instant::now();
@@ -212,13 +238,16 @@ async fn handle_dns_response(mut handler: impl SelectiveRoutingHandler) -> anyho
// Let's first lookup the qname in the Forwardings to see if we have non-expired answers // Let's first lookup the qname in the Forwardings to see if we have non-expired answers
if entries.len() > 0 && entries.iter().all(|e| e.expires > now) { if entries.len() > 0 && entries.iter().all(|e| e.expires > now) {
handler.reply( handler
forge_replies(&entries, dns_name_string, qname_parts, buf) .reply(forge_replies(&entries, dns_name_string, qname_parts, buf))
).await.or(Err(anyhow::Error::msg("Failed to reply!")))?; .await
.or(Err(anyhow::Error::msg("Failed to reply!")))?;
return Ok(()); return Ok(());
} }
let upstream_reply = handler.query_upstream().await.or(Err(anyhow::Error::msg("Failed to query upstream resolver.")))?; let upstream_reply = handler.query_upstream().await.or(Err(anyhow::Error::msg(
"Failed to query upstream resolver.",
)))?;
// Try an answer from the upstream response that has type A. // Try an answer from the upstream response that has type A.
let answer_iterator = AnswerIterator::from(&upstream_reply).or(Err(anyhow::Error::msg( let answer_iterator = AnswerIterator::from(&upstream_reply).or(Err(anyhow::Error::msg(
@@ -257,8 +286,11 @@ async fn handle_dns_response(mut handler: impl SelectiveRoutingHandler) -> anyho
} }
if !should_forge { if !should_forge {
handler.reply(upstream_reply).await.or(Err(anyhow::Error::msg("Failed to reply.")))?; handler
return Ok(()) .reply(upstream_reply)
.await
.or(Err(anyhow::Error::msg("Failed to reply.")))?;
return Ok(());
} }
// Normalise a_answers so we're working with Ipv4Addr // Normalise a_answers so we're working with Ipv4Addr
@@ -268,9 +300,7 @@ async fn handle_dns_response(mut handler: impl SelectiveRoutingHandler) -> anyho
( (
answer.name, answer.name,
answer.ttl, answer.ttl,
Ipv4Addr::from( Ipv4Addr::from(<Vec<u8> as TryInto<[u8; 4]>>::try_into(answer.rdata).unwrap()),
<Vec<u8> as TryInto<[u8; 4]>>::try_into(answer.rdata).unwrap(),
),
) )
}) })
.collect::<_>(); .collect::<_>();
@@ -360,7 +390,10 @@ async fn handle_dns_response(mut handler: impl SelectiveRoutingHandler) -> anyho
replies.extend(entries); replies.extend(entries);
} }
handler.reply(forge_replies(&replies, dns_name_string, qname_parts, buf)).await.or(Err(anyhow::Error::msg("Failed to reply.")))?; handler
.reply(forge_replies(&replies, dns_name_string, qname_parts, buf))
.await
.or(Err(anyhow::Error::msg("Failed to reply.")))?;
Ok(()) Ok(())
} }
@@ -375,38 +408,31 @@ struct Forwarding {
enum FwmarkDomainMatchType { enum FwmarkDomainMatchType {
Exact, Exact,
Suffix Suffix,
} }
struct FwmarkConfig { struct FwmarkConfig {
entries: Vec<(FwmarkDomainMatchType, Vec<Vec<u8>>, u32)> entries: Vec<(FwmarkDomainMatchType, Vec<Vec<u8>>, u32)>,
} }
impl FwmarkConfig { impl FwmarkConfig {
fn new() -> Self { fn new() -> Self {
Self { Self {
entries: Vec::new() entries: Vec::new(),
} }
} }
fn insert(&mut self, dns_name: Vec<u8>, fwmark: u32) { fn insert(&mut self, dns_name: Vec<u8>, fwmark: u32) {
self.entries.push( self.entries.push((
( FwmarkDomainMatchType::Exact,
FwmarkDomainMatchType::Exact, dns_name_to_parts(&mut Cursor::from(dns_name.as_slice())).unwrap(),
dns_name_to_parts(&mut Cursor::from(dns_name.as_slice())).unwrap(), fwmark,
fwmark ));
)
);
} }
fn insert_wildcard(&mut self, suffix_parts: Vec<Vec<u8>>, fwmark: u32) { fn insert_wildcard(&mut self, suffix_parts: Vec<Vec<u8>>, fwmark: u32) {
self.entries.push( self.entries
( .push((FwmarkDomainMatchType::Suffix, suffix_parts, fwmark));
FwmarkDomainMatchType::Suffix,
suffix_parts,
fwmark
)
);
} }
fn get(&self, dns_name: &Vec<u8>) -> Option<u32> { fn get(&self, dns_name: &Vec<u8>) -> Option<u32> {
@@ -416,12 +442,12 @@ impl FwmarkConfig {
match match_type { match match_type {
FwmarkDomainMatchType::Exact => { FwmarkDomainMatchType::Exact => {
if parts == *_parts { if parts == *_parts {
return Some(*fwmark) return Some(*fwmark);
} }
}, }
FwmarkDomainMatchType::Suffix => { FwmarkDomainMatchType::Suffix => {
if parts.ends_with(_parts.as_slice()) { if parts.ends_with(_parts.as_slice()) {
return Some(*fwmark) return Some(*fwmark);
} }
} }
} }
@@ -444,7 +470,7 @@ static LISTEN_ADDR: OnceLock<String> = OnceLock::new();
struct Config { struct Config {
listen_addr: SocketAddrV4, listen_addr: SocketAddrV4,
upstream_resolver: SocketAddrV4 upstream_resolver: SocketAddrV4,
} }
trait SelectiveRoutingHandler { trait SelectiveRoutingHandler {
@@ -473,22 +499,18 @@ impl UDPSelectiveRoutingHandler {
Ok((_, SocketAddr::V4(addr))) => { Ok((_, SocketAddr::V4(addr))) => {
reply_address = *addr.ip(); reply_address = *addr.ip();
reply_port = addr.port(); reply_port = addr.port();
},
Ok(_) => unimplemented!(),
Err(_) => {
return Err(())
} }
Ok(_) => unimplemented!(),
Err(_) => return Err(()),
} }
Ok( Ok(Self {
Self { config: Arc::clone(config),
config: Arc::clone(config), original_message: recv_buf,
original_message: recv_buf, reply_address,
reply_address, reply_port,
reply_port, reply_socket: Arc::clone(receiving_socket),
reply_socket: Arc::clone(receiving_socket) })
}
)
} }
} }
@@ -496,18 +518,27 @@ impl SelectiveRoutingHandler for UDPSelectiveRoutingHandler {
async fn query_upstream(&self) -> Result<Vec<u8>, ()> { async fn query_upstream(&self) -> Result<Vec<u8>, ()> {
let query_upstream_socket = UdpSocket::bind("0.0.0.0:0").await.or(Err(()))?; let query_upstream_socket = UdpSocket::bind("0.0.0.0:0").await.or(Err(()))?;
query_upstream_socket.send_to(&self.original_message, self.config.upstream_resolver).await.or(Err(()))?; query_upstream_socket
.send_to(&self.original_message, self.config.upstream_resolver)
.await
.or(Err(()))?;
let mut recv_buf = vec![0; 512]; let mut recv_buf = vec![0; 512];
query_upstream_socket.recv(&mut recv_buf).await.or(Err(()))?; query_upstream_socket
.recv(&mut recv_buf)
.await
.or(Err(()))?;
Ok(recv_buf) Ok(recv_buf)
} }
async fn reply(&mut self, message: Vec<u8>) -> Result<(), ()> { async fn reply(&mut self, message: Vec<u8>) -> Result<(), ()> {
let _ = self.reply_socket.send_to(&message, (self.reply_address, self.reply_port)).await; let _ = self
.reply_socket
.send_to(&message, (self.reply_address, self.reply_port))
.await;
Ok(()) Ok(())
} }
@@ -519,13 +550,16 @@ impl SelectiveRoutingHandler for UDPSelectiveRoutingHandler {
struct TCPSelectiveRoutingHandler { struct TCPSelectiveRoutingHandler {
config: Arc<Config>, config: Arc<Config>,
original_message: Vec<u8>, original_message: Vec<u8>,
reply_stream: TcpStream reply_stream: TcpStream,
} }
impl TCPSelectiveRoutingHandler { impl TCPSelectiveRoutingHandler {
async fn from(config: &Arc<Config>, mut receiving_stream: TcpStream) -> Result<Self, ()> { async fn from(config: &Arc<Config>, mut receiving_stream: TcpStream) -> Result<Self, ()> {
let mut message_size_buf: [u8; 2] = [0, 0]; let mut message_size_buf: [u8; 2] = [0, 0];
receiving_stream.read(&mut message_size_buf).await.or(Err(()))?; receiving_stream
.read(&mut message_size_buf)
.await
.or(Err(()))?;
let message_size: u16 = u16::from_be_bytes(message_size_buf); let message_size: u16 = u16::from_be_bytes(message_size_buf);
@@ -533,26 +567,36 @@ impl TCPSelectiveRoutingHandler {
receiving_stream.read(&mut message).await.or(Err(()))?; receiving_stream.read(&mut message).await.or(Err(()))?;
Ok( Ok(Self {
Self { config: Arc::clone(config),
config: Arc::clone(config), original_message: message,
original_message: message, reply_stream: receiving_stream,
reply_stream: receiving_stream })
}
)
} }
} }
impl SelectiveRoutingHandler for TCPSelectiveRoutingHandler { impl SelectiveRoutingHandler for TCPSelectiveRoutingHandler {
async fn query_upstream(&self) -> Result<Vec<u8>, ()> { async fn query_upstream(&self) -> Result<Vec<u8>, ()> {
let upstream_socket = TcpSocket::new_v4().or(Err(()))?; let upstream_socket = TcpSocket::new_v4().or(Err(()))?;
let mut upstream_stream = upstream_socket.connect(self.config.upstream_resolver.into()).await.or(Err(()))?; let mut upstream_stream = upstream_socket
.connect(self.config.upstream_resolver.into())
.await
.or(Err(()))?;
upstream_stream.write(&(self.original_message.len() as u16).to_be_bytes()).await.or(Err(()))?; upstream_stream
upstream_stream.write(&self.original_message).await.or(Err(()))?; .write(&(self.original_message.len() as u16).to_be_bytes())
.await
.or(Err(()))?;
upstream_stream
.write(&self.original_message)
.await
.or(Err(()))?;
let mut message_size_buf: [u8; 2] = [0, 0]; let mut message_size_buf: [u8; 2] = [0, 0];
upstream_stream.read(&mut message_size_buf).await.or(Err(()))?; upstream_stream
.read(&mut message_size_buf)
.await
.or(Err(()))?;
let message_size: u16 = u16::from_be_bytes(message_size_buf); let message_size: u16 = u16::from_be_bytes(message_size_buf);
@@ -564,7 +608,10 @@ impl SelectiveRoutingHandler for TCPSelectiveRoutingHandler {
} }
async fn reply(&mut self, message: Vec<u8>) -> Result<(), ()> { async fn reply(&mut self, message: Vec<u8>) -> Result<(), ()> {
self.reply_stream.write(&(message.len() as u16).to_be_bytes()).await.or(Err(()))?; self.reply_stream
.write(&(message.len() as u16).to_be_bytes())
.await
.or(Err(()))?;
self.reply_stream.write(&message).await.or(Err(()))?; self.reply_stream.write(&message).await.or(Err(()))?;
Ok(()) Ok(())
@@ -632,7 +679,7 @@ async fn main() -> anyhow::Result<()> {
let config = Arc::new(Config { let config = Arc::new(Config {
listen_addr: LISTEN_ADDR.get().unwrap().parse::<SocketAddrV4>().unwrap(), listen_addr: LISTEN_ADDR.get().unwrap().parse::<SocketAddrV4>().unwrap(),
upstream_resolver: UPSTREAM_DNS.get().unwrap().parse::<SocketAddrV4>().unwrap() upstream_resolver: UPSTREAM_DNS.get().unwrap().parse::<SocketAddrV4>().unwrap(),
}); });
tokio::try_join!( tokio::try_join!(
@@ -648,9 +695,9 @@ async fn main() -> anyhow::Result<()> {
eprintln!("{}", e) eprintln!("{}", e)
} }
}); });
}, }
Err(_) => () Err(_) => (),
} }
} }
} }
@@ -661,17 +708,22 @@ async fn main() -> anyhow::Result<()> {
loop { loop {
match recv_socket_tcp.accept().await { match recv_socket_tcp.accept().await {
Ok((stream, _)) => { Ok((stream, _)) => {
eprintln!("[TCP] Accepted connection from {}", stream.peer_addr().unwrap()); eprintln!(
if let Ok(handler) = TCPSelectiveRoutingHandler::from(&config, stream).await { "[TCP] Accepted connection from {}",
stream.peer_addr().unwrap()
);
if let Ok(handler) =
TCPSelectiveRoutingHandler::from(&config, stream).await
{
tokio::spawn(async { tokio::spawn(async {
if let Err(e) = handle_dns_response(handler).await { if let Err(e) = handle_dns_response(handler).await {
eprintln!("{}", e) eprintln!("{}", e)
} }
}); });
} }
}, }
Err(_) => () Err(_) => (),
} }
} }
} }