Compare commits
3 Commits
8e7220f704
...
d0b5f8c7be
| Author | SHA1 | Date | |
|---|---|---|---|
|
d0b5f8c7be
|
|||
|
7d73832913
|
|||
|
ffd3a5b268
|
+5
-2
@@ -44,7 +44,10 @@ impl<'a, T> Cursor<'a, T> {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn next_array<const N: usize>(&mut self) -> Option<[T; N]> where [T; N]: TryFrom<&'a [T]> {
|
||||
pub fn next_array<const N: usize>(&mut self) -> Option<[T; N]>
|
||||
where
|
||||
[T; N]: TryFrom<&'a [T]>,
|
||||
{
|
||||
Some(self.next_slice(N)?.try_into().ok()?)
|
||||
}
|
||||
|
||||
@@ -57,4 +60,4 @@ impl<'a, T> Cursor<'a, T> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+8
-8
@@ -81,19 +81,19 @@ pub struct AnswerIterator<'a> {
|
||||
}
|
||||
|
||||
impl<'a> AnswerIterator<'a> {
|
||||
pub fn from(buf: &'a [u8]) -> Option<Self> {
|
||||
pub fn from(buf: &'a [u8]) -> Result<Self, ()> {
|
||||
let mut cursor = Cursor::from(buf);
|
||||
cursor.seek(4).ok()?;
|
||||
let qdcount = u16::from_be_bytes(cursor.next_array::<2>()?);
|
||||
let ancount = u16::from_be_bytes(cursor.next_array::<2>()?);
|
||||
cursor.forward(4).ok()?;
|
||||
cursor.seek(4).or(Err(()))?;
|
||||
let qdcount = u16::from_be_bytes(cursor.next_array::<2>().ok_or(())?);
|
||||
let ancount = u16::from_be_bytes(cursor.next_array::<2>().ok_or(())?);
|
||||
cursor.forward(4).or(Err(()))?;
|
||||
|
||||
// Skip past the question section
|
||||
for _ in 0..qdcount {
|
||||
dns_name_len(&mut cursor)?;
|
||||
cursor.forward(4).ok()?;
|
||||
dns_name_len(&mut cursor).ok_or(());
|
||||
cursor.forward(4).or(Err(()))?;
|
||||
}
|
||||
Some(Self { cursor, ancount })
|
||||
Ok(Self { cursor, ancount })
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+414
-298
@@ -6,12 +6,13 @@ mod dns_parser;
|
||||
mod ip_pool;
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::net::Ipv4Addr;
|
||||
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
|
||||
use std::process::{Command, Output};
|
||||
use std::sync::OnceLock;
|
||||
use std::sync::{Arc, OnceLock};
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use ini::Ini;
|
||||
use tokio::net::TcpSocket;
|
||||
use tokio::{
|
||||
io::{AsyncReadExt, AsyncWriteExt},
|
||||
net::{TcpListener, TcpStream, UdpSocket},
|
||||
@@ -20,8 +21,8 @@ use tokio::{
|
||||
|
||||
use crate::cursor::Cursor;
|
||||
use crate::dns_parser::{
|
||||
Answer, AnswerIterator, dns_name_to_parts, dns_parts_to_string,
|
||||
parts_to_dns_name, string_to_dns_name,
|
||||
Answer, AnswerIterator, dns_name_to_parts, dns_parts_to_string, parts_to_dns_name,
|
||||
string_to_dns_name,
|
||||
};
|
||||
use crate::ip_pool::IpPool;
|
||||
|
||||
@@ -121,32 +122,12 @@ async fn teardown_forwarding(
|
||||
let _dnat_output = run_command_string(dnat_command).unwrap();
|
||||
}
|
||||
|
||||
async fn query_upstream_resolvers(original_message: &[u8], tcp: bool) -> anyhow::Result<Vec<u8>> {
|
||||
let mut upstream_reply = Vec::new();
|
||||
|
||||
if tcp {
|
||||
let mut upstream = TcpStream::connect(UPSTREAM_DNS.get().unwrap()).await?;
|
||||
//upstream.write(&(original_message.len() as u16).to_be_bytes()).await?;
|
||||
upstream.write(&original_message).await?;
|
||||
let mut size_buffer = [0u8; 2];
|
||||
upstream.read(&mut size_buffer).await?;
|
||||
let size: u16 = u16::from_be_bytes(size_buffer);
|
||||
upstream_reply.resize(size as usize + 2, 0u8);
|
||||
upstream.read(&mut upstream_reply[2..]).await?;
|
||||
upstream_reply[0] = size_buffer[0];
|
||||
upstream_reply[1] = size_buffer[1];
|
||||
} else {
|
||||
let upstream = UdpSocket::bind("0.0.0.0:0").await?;
|
||||
upstream.connect(UPSTREAM_DNS.get().unwrap()).await?;
|
||||
upstream.send(original_message).await?;
|
||||
upstream_reply.resize(512, 0);
|
||||
upstream.recv_from(&mut upstream_reply[..512]).await?;
|
||||
}
|
||||
|
||||
return Ok(upstream_reply);
|
||||
}
|
||||
|
||||
fn forge_replies(replies: &Vec<Forwarding>, dns_name_string: String, qname_parts: Vec<Vec<u8>>, original_message: &[u8], reply_buf: &mut Vec<u8>, tcp: bool) {
|
||||
fn forge_replies(
|
||||
replies: &Vec<Forwarding>,
|
||||
dns_name_string: String,
|
||||
qname_parts: Vec<Vec<u8>>,
|
||||
original_message: &[u8],
|
||||
) -> Vec<u8> {
|
||||
let reply: [u8; 12] = [
|
||||
0u8,
|
||||
0, // ID
|
||||
@@ -162,8 +143,6 @@ fn forge_replies(replies: &Vec<Forwarding>, dns_name_string: String, qname_parts
|
||||
0, // Arcount
|
||||
];
|
||||
|
||||
let offset = if tcp { 2 } else { 0 };
|
||||
|
||||
let now = Instant::now();
|
||||
|
||||
println!(
|
||||
@@ -180,9 +159,11 @@ fn forge_replies(replies: &Vec<Forwarding>, dns_name_string: String, qname_parts
|
||||
);
|
||||
|
||||
let mut new_reply = reply.clone().to_vec();
|
||||
new_reply[0..2].copy_from_slice(&original_message[offset..][0..=1]);
|
||||
new_reply[0..2].copy_from_slice(&original_message[0..=1]);
|
||||
new_reply[6..8].copy_from_slice(&(replies.len() as u16).to_be_bytes());
|
||||
new_reply.extend_from_slice(&original_message[offset..][12..=12 + qname_parts.iter().map(|p| p.len()+1).sum::<usize>() + 1 + 3]);
|
||||
new_reply.extend_from_slice(
|
||||
&original_message[12..=12 + qname_parts.iter().map(|p| p.len() + 1).sum::<usize>() + 1 + 3],
|
||||
);
|
||||
for reply in replies {
|
||||
new_reply.extend_from_slice(&parts_to_dns_name(&qname_parts));
|
||||
new_reply.extend_from_slice(&[0, 1]);
|
||||
@@ -191,41 +172,58 @@ fn forge_replies(replies: &Vec<Forwarding>, dns_name_string: String, qname_parts
|
||||
new_reply.extend_from_slice(&[0, 4]);
|
||||
new_reply.extend_from_slice(&reply.forged_ip.octets());
|
||||
}
|
||||
if tcp {
|
||||
reply_buf.extend_from_slice(&(new_reply.len() as u16).to_be_bytes());
|
||||
}
|
||||
reply_buf.extend_from_slice(&new_reply);
|
||||
|
||||
new_reply
|
||||
}
|
||||
|
||||
async fn handle_dns_response(buf: &[u8], reply_buf: &mut Vec<u8>, tcp: bool) -> anyhow::Result<()> {
|
||||
|
||||
let offset: usize = if tcp { 2 } else { 0 };
|
||||
let mut cursor = Cursor::from(&buf[offset..]);
|
||||
async fn handle_dns_response(mut handler: impl SelectiveRoutingHandler) -> anyhow::Result<()> {
|
||||
let buf = handler.get_original_message();
|
||||
let mut cursor = Cursor::from(buf);
|
||||
|
||||
// Identify some metadata from the query
|
||||
if cursor.seek(4).is_err() {
|
||||
return Err(anyhow::Error::msg("Failed to seek to QDCount"));
|
||||
}
|
||||
let qdcount = u16::from_be_bytes(cursor.next_array::<2>().ok_or(anyhow::Error::msg("Failed to read QDCount"))?);
|
||||
|
||||
let qdcount = u16::from_be_bytes(
|
||||
cursor
|
||||
.next_array::<2>()
|
||||
.ok_or(anyhow::Error::msg("Failed to read QDCount"))?,
|
||||
);
|
||||
if qdcount != 1 {
|
||||
eprintln!("Got qdcount: {}", qdcount);
|
||||
return Err(anyhow::Error::msg(
|
||||
return Err(anyhow::Error::msg(format!(
|
||||
"Missing question from query. Got qdcount {}",
|
||||
));
|
||||
qdcount
|
||||
)));
|
||||
}
|
||||
if cursor.forward(6).is_err() {
|
||||
return Err(anyhow::Error::msg("Failed to seek to question section"));
|
||||
}
|
||||
let qname_parts =
|
||||
dns_name_to_parts(&mut cursor).ok_or(anyhow::Error::msg("Failed to decode QName."))?;
|
||||
let qtype = u16::from_be_bytes(cursor.next_array::<2>().ok_or(anyhow::Error::msg("Failed to read QType."))?);
|
||||
let qclass = u16::from_be_bytes(cursor.next_array::<2>().ok_or(anyhow::Error::msg("Failed to read QClass."))?);
|
||||
let qtype = u16::from_be_bytes(
|
||||
cursor
|
||||
.next_array::<2>()
|
||||
.ok_or(anyhow::Error::msg("Failed to read QType."))?,
|
||||
);
|
||||
let qclass = u16::from_be_bytes(
|
||||
cursor
|
||||
.next_array::<2>()
|
||||
.ok_or(anyhow::Error::msg("Failed to read QClass."))?,
|
||||
);
|
||||
|
||||
// If the query is for anything other than qclass IN or qtype A, just forward the query upstream
|
||||
if qclass != 1 || qtype != 1 {
|
||||
let upstream_reply = query_upstream_resolvers(buf, tcp).await?;
|
||||
reply_buf.extend_from_slice(&upstream_reply);
|
||||
return Ok(())
|
||||
handler
|
||||
.reply(handler.query_upstream().await.or(Err(anyhow::Error::msg(
|
||||
"Failed to query upstream resolver.",
|
||||
)))?)
|
||||
.await
|
||||
.or(Err(anyhow::Error::msg(
|
||||
"Failed to query upstream resolver.",
|
||||
)))?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let now = Instant::now();
|
||||
@@ -240,158 +238,164 @@ async fn handle_dns_response(buf: &[u8], reply_buf: &mut Vec<u8>, tcp: bool) ->
|
||||
|
||||
// Let's first lookup the qname in the Forwardings to see if we have non-expired answers
|
||||
if entries.len() > 0 && entries.iter().all(|e| e.expires > now) {
|
||||
forge_replies(&entries, dns_name_string, qname_parts, buf, reply_buf, tcp);
|
||||
return Ok(());
|
||||
} else {
|
||||
let upstream_reply = query_upstream_resolvers(buf, tcp).await?;
|
||||
|
||||
// Try an answer from the upstream response that has type A.
|
||||
let a_answers = match AnswerIterator::from(&upstream_reply[offset..]) {
|
||||
Some(answers) => answers
|
||||
.filter(|Answer { rrtype, .. }| *rrtype == 1)
|
||||
.collect::<Vec<Answer>>(),
|
||||
None => {
|
||||
return Err(anyhow::Error::msg(
|
||||
"Failed to extract answers from upstream reply!",
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let mut forge = true;
|
||||
if qtype != 1 {
|
||||
// Only forge for queries with an A qtype
|
||||
eprintln!("Not forging due non-A type question.");
|
||||
forge = false;
|
||||
}
|
||||
|
||||
if a_answers.len() == 0 {
|
||||
// If no A type answer, don't forge
|
||||
eprintln!("Not forging due to no returned A type answers.");
|
||||
forge = false;
|
||||
}
|
||||
|
||||
if a_answers.iter().any(|a| a.rdata.len() != 4) {
|
||||
eprintln!("Not forging due to malformed A type answer.",);
|
||||
forge = false;
|
||||
}
|
||||
|
||||
if !FwmarkConfigMap
|
||||
.lock()
|
||||
handler
|
||||
.reply(forge_replies(&entries, dns_name_string, qname_parts, buf))
|
||||
.await
|
||||
.get(&parts_to_dns_name(&qname_parts))
|
||||
.is_some()
|
||||
{
|
||||
eprintln!("Not forging due to non-matching qname.");
|
||||
forge = false;
|
||||
}
|
||||
|
||||
if forge {
|
||||
// Normalise a_answers so we're working with Ipv4Addr
|
||||
let normalised_answers: Vec<(Vec<Vec<u8>>, u32, Ipv4Addr)> = a_answers
|
||||
.iter()
|
||||
.map(|answer| {
|
||||
(
|
||||
answer.name.clone(),
|
||||
answer.ttl,
|
||||
Ipv4Addr::from(
|
||||
<Vec<u8> as TryInto<[u8; 4]>>::try_into(answer.rdata.clone()).unwrap(),
|
||||
),
|
||||
)
|
||||
})
|
||||
.collect::<_>();
|
||||
|
||||
let mut replies: Vec<Forwarding> = Vec::new();
|
||||
|
||||
// Determine if we need to create or renew our entries
|
||||
if entries.len() > 0 && entries.iter().all(|e| e.expires > now) {
|
||||
println!("Found not expired forwardings for {}", dns_name_string);
|
||||
replies.extend(entries.clone() as Vec<Forwarding>);
|
||||
} else {
|
||||
// We want to identify which of our A answers already exist in the ForwardingMap
|
||||
let existing_entries = entries
|
||||
.iter()
|
||||
.filter(|e| {
|
||||
normalised_answers
|
||||
.iter()
|
||||
.any(|(_, _, addr)| e.original_ip == *addr)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// Let's also identify which entries don't match any of the replies
|
||||
let nonexisting_entries = entries
|
||||
.iter()
|
||||
.filter(|e| {
|
||||
!normalised_answers
|
||||
.iter()
|
||||
.any(|(_, _, addr)| e.original_ip == *addr)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// And now let's find the replies that don't match any of the current entries
|
||||
let new_answers = normalised_answers
|
||||
.iter()
|
||||
.filter(|answer| entries.iter().all(|e| e.original_ip != answer.2))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// Acquire the forwarding map
|
||||
let mut forwarding_map = ForwardingMap.lock().await;
|
||||
|
||||
// Remove all non-existing entries
|
||||
for entry in nonexisting_entries {
|
||||
println!(
|
||||
"Removed forwarding for {} with real IP {} / forged IP {} as the upstream no longer returns this value",
|
||||
dns_name_string, entry.original_ip, entry.forged_ip
|
||||
);
|
||||
teardown_forwarding(&mut forwarding_map, &dns_name, entry.original_ip).await;
|
||||
}
|
||||
|
||||
// Add new answers
|
||||
for (_, ttl, original_ip) in new_answers {
|
||||
let forged_ip = IpAllocator.lock().await.acquire().unwrap();
|
||||
setup_forwarding(
|
||||
&mut forwarding_map,
|
||||
&dns_name,
|
||||
*ttl,
|
||||
*original_ip,
|
||||
forged_ip,
|
||||
)
|
||||
.await;
|
||||
println!(
|
||||
"Added forwarding for {} with real IP {} / forged IP {}",
|
||||
dns_name_string, original_ip, forged_ip
|
||||
);
|
||||
}
|
||||
|
||||
// For all the existing entries, update their TTL if they're expired
|
||||
for entry in existing_entries.iter() {
|
||||
println!(
|
||||
"Updating TTL of existing entry {}/{}",
|
||||
dns_name_string, entry.forged_ip
|
||||
);
|
||||
if entry.expires < now {
|
||||
teardown_forwarding(&mut forwarding_map, &dns_name, entry.original_ip).await;
|
||||
setup_forwarding(
|
||||
&mut forwarding_map,
|
||||
&dns_name,
|
||||
entry.ttl,
|
||||
entry.original_ip,
|
||||
entry.forged_ip,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
let entries = forwarding_map.get(&dns_name).unwrap().clone();
|
||||
replies.extend(entries);
|
||||
}
|
||||
|
||||
forge_replies(&replies, dns_name_string, qname_parts, buf, reply_buf, tcp);
|
||||
} else {
|
||||
reply_buf.extend_from_slice(&upstream_reply);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
.or(Err(anyhow::Error::msg("Failed to reply!")))?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let upstream_reply = handler.query_upstream().await.or(Err(anyhow::Error::msg(
|
||||
"Failed to query upstream resolver.",
|
||||
)))?;
|
||||
|
||||
// Try an answer from the upstream response that has type A.
|
||||
let answer_iterator = AnswerIterator::from(&upstream_reply).or(Err(anyhow::Error::msg(
|
||||
"Failed to extract answers from upstream reply!",
|
||||
)))?;
|
||||
let a_answers = answer_iterator
|
||||
.filter(|Answer { rrtype, .. }| *rrtype == 1)
|
||||
.collect::<Vec<Answer>>();
|
||||
|
||||
let mut should_forge = true;
|
||||
if qtype != 1 {
|
||||
// Only forge for queries with an A qtype
|
||||
eprintln!("Not forging due non-A type question.");
|
||||
should_forge = false;
|
||||
}
|
||||
|
||||
if a_answers.len() == 0 {
|
||||
// If no A type answer, don't forge
|
||||
eprintln!("Not forging due to no returned A type answers.");
|
||||
should_forge = false;
|
||||
}
|
||||
|
||||
if a_answers.iter().any(|a| a.rdata.len() != 4) {
|
||||
eprintln!("Not forging due to malformed A type answer.",);
|
||||
should_forge = false;
|
||||
}
|
||||
|
||||
if !FwmarkConfigMap
|
||||
.lock()
|
||||
.await
|
||||
.get(&parts_to_dns_name(&qname_parts))
|
||||
.is_some()
|
||||
{
|
||||
eprintln!("Not forging due to non-matching qname.");
|
||||
should_forge = false;
|
||||
}
|
||||
|
||||
if !should_forge {
|
||||
handler
|
||||
.reply(upstream_reply)
|
||||
.await
|
||||
.or(Err(anyhow::Error::msg("Failed to reply.")))?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Normalise a_answers so we're working with Ipv4Addr
|
||||
let normalised_answers: Vec<(Vec<Vec<u8>>, u32, Ipv4Addr)> = a_answers
|
||||
.into_iter()
|
||||
.map(|answer| {
|
||||
(
|
||||
answer.name,
|
||||
answer.ttl,
|
||||
Ipv4Addr::from(<Vec<u8> as TryInto<[u8; 4]>>::try_into(answer.rdata).unwrap()),
|
||||
)
|
||||
})
|
||||
.collect::<_>();
|
||||
|
||||
let mut replies: Vec<Forwarding> = Vec::new();
|
||||
|
||||
// Determine if we need to create or renew our entries
|
||||
if entries.len() > 0 && entries.iter().all(|e| e.expires > now) {
|
||||
println!("Found not expired forwardings for {}", dns_name_string);
|
||||
replies.extend(entries.clone() as Vec<Forwarding>);
|
||||
} else {
|
||||
// We want to identify which of our A answers already exist in the ForwardingMap
|
||||
let existing_entries = entries
|
||||
.iter()
|
||||
.filter(|e| {
|
||||
normalised_answers
|
||||
.iter()
|
||||
.any(|(_, _, addr)| e.original_ip == *addr)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// Let's also identify which entries don't match any of the replies
|
||||
let nonexisting_entries = entries
|
||||
.iter()
|
||||
.filter(|e| {
|
||||
!normalised_answers
|
||||
.iter()
|
||||
.any(|(_, _, addr)| e.original_ip == *addr)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// And now let's find the replies that don't match any of the current entries
|
||||
let new_answers = normalised_answers
|
||||
.iter()
|
||||
.filter(|answer| entries.iter().all(|e| e.original_ip != answer.2))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// Acquire the forwarding map
|
||||
let mut forwarding_map = ForwardingMap.lock().await;
|
||||
|
||||
// Remove all non-existing entries
|
||||
for entry in nonexisting_entries {
|
||||
println!(
|
||||
"Removed forwarding for {} with real IP {} / forged IP {} as the upstream no longer returns this value",
|
||||
dns_name_string, entry.original_ip, entry.forged_ip
|
||||
);
|
||||
teardown_forwarding(&mut forwarding_map, &dns_name, entry.original_ip).await;
|
||||
}
|
||||
|
||||
// Add new answers
|
||||
for (_, ttl, original_ip) in new_answers {
|
||||
let forged_ip = IpAllocator.lock().await.acquire().unwrap();
|
||||
setup_forwarding(
|
||||
&mut forwarding_map,
|
||||
&dns_name,
|
||||
*ttl,
|
||||
*original_ip,
|
||||
forged_ip,
|
||||
)
|
||||
.await;
|
||||
println!(
|
||||
"Added forwarding for {} with real IP {} / forged IP {}",
|
||||
dns_name_string, original_ip, forged_ip
|
||||
);
|
||||
}
|
||||
|
||||
// For all the existing entries, update their TTL if they're expired
|
||||
for entry in existing_entries.iter() {
|
||||
println!(
|
||||
"Updating TTL of existing entry {}/{}",
|
||||
dns_name_string, entry.forged_ip
|
||||
);
|
||||
if entry.expires < now {
|
||||
teardown_forwarding(&mut forwarding_map, &dns_name, entry.original_ip).await;
|
||||
setup_forwarding(
|
||||
&mut forwarding_map,
|
||||
&dns_name,
|
||||
entry.ttl,
|
||||
entry.original_ip,
|
||||
entry.forged_ip,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
let entries = forwarding_map.get(&dns_name).unwrap().clone();
|
||||
replies.extend(entries);
|
||||
}
|
||||
|
||||
handler
|
||||
.reply(forge_replies(&replies, dns_name_string, qname_parts, buf))
|
||||
.await
|
||||
.or(Err(anyhow::Error::msg("Failed to reply.")))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
@@ -404,38 +408,31 @@ struct Forwarding {
|
||||
|
||||
enum FwmarkDomainMatchType {
|
||||
Exact,
|
||||
Suffix
|
||||
Suffix,
|
||||
}
|
||||
|
||||
struct FwmarkConfig {
|
||||
entries: Vec<(FwmarkDomainMatchType, Vec<Vec<u8>>, u32)>
|
||||
entries: Vec<(FwmarkDomainMatchType, Vec<Vec<u8>>, u32)>,
|
||||
}
|
||||
|
||||
impl FwmarkConfig {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
entries: Vec::new()
|
||||
entries: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn insert(&mut self, dns_name: Vec<u8>, fwmark: u32) {
|
||||
self.entries.push(
|
||||
(
|
||||
FwmarkDomainMatchType::Exact,
|
||||
dns_name_to_parts(&mut Cursor::from(dns_name.as_slice())).unwrap(),
|
||||
fwmark
|
||||
)
|
||||
);
|
||||
self.entries.push((
|
||||
FwmarkDomainMatchType::Exact,
|
||||
dns_name_to_parts(&mut Cursor::from(dns_name.as_slice())).unwrap(),
|
||||
fwmark,
|
||||
));
|
||||
}
|
||||
|
||||
fn insert_wildcard(&mut self, suffix_parts: Vec<Vec<u8>>, fwmark: u32) {
|
||||
self.entries.push(
|
||||
(
|
||||
FwmarkDomainMatchType::Suffix,
|
||||
suffix_parts,
|
||||
fwmark
|
||||
)
|
||||
);
|
||||
self.entries
|
||||
.push((FwmarkDomainMatchType::Suffix, suffix_parts, fwmark));
|
||||
}
|
||||
|
||||
fn get(&self, dns_name: &Vec<u8>) -> Option<u32> {
|
||||
@@ -445,12 +442,12 @@ impl FwmarkConfig {
|
||||
match match_type {
|
||||
FwmarkDomainMatchType::Exact => {
|
||||
if parts == *_parts {
|
||||
return Some(*fwmark)
|
||||
return Some(*fwmark);
|
||||
}
|
||||
},
|
||||
}
|
||||
FwmarkDomainMatchType::Suffix => {
|
||||
if parts.ends_with(_parts.as_slice()) {
|
||||
return Some(*fwmark)
|
||||
return Some(*fwmark);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -471,86 +468,157 @@ static COMMAND_PREFIX: OnceLock<String> = OnceLock::new();
|
||||
static UPSTREAM_DNS: OnceLock<String> = OnceLock::new();
|
||||
static LISTEN_ADDR: OnceLock<String> = OnceLock::new();
|
||||
|
||||
async fn tcp_handler(listener: TcpListener) {
|
||||
loop {
|
||||
let accepted = listener.accept().await;
|
||||
if let Err(e) = &accepted {
|
||||
eprintln!("[TCP] Failed to accept TCP socket with error: {:?}", e);
|
||||
struct Config {
|
||||
listen_addr: SocketAddrV4,
|
||||
upstream_resolver: SocketAddrV4,
|
||||
}
|
||||
|
||||
trait SelectiveRoutingHandler {
|
||||
async fn query_upstream(&self) -> Result<Vec<u8>, ()>;
|
||||
async fn reply(&mut self, message: Vec<u8>) -> Result<(), ()>;
|
||||
|
||||
fn get_original_message(&self) -> &[u8];
|
||||
}
|
||||
|
||||
struct UDPSelectiveRoutingHandler {
|
||||
config: Arc<Config>,
|
||||
original_message: Vec<u8>,
|
||||
reply_address: Ipv4Addr,
|
||||
reply_port: u16,
|
||||
reply_socket: Arc<UdpSocket>,
|
||||
}
|
||||
|
||||
impl UDPSelectiveRoutingHandler {
|
||||
pub async fn from(config: &Arc<Config>, receiving_socket: &Arc<UdpSocket>) -> Result<Self, ()> {
|
||||
let mut recv_buf = vec![0; 512];
|
||||
|
||||
let reply_address;
|
||||
let reply_port;
|
||||
|
||||
match receiving_socket.recv_from(&mut recv_buf).await {
|
||||
Ok((_, SocketAddr::V4(addr))) => {
|
||||
reply_address = *addr.ip();
|
||||
reply_port = addr.port();
|
||||
}
|
||||
Ok(_) => unimplemented!(),
|
||||
Err(_) => return Err(()),
|
||||
}
|
||||
|
||||
let (mut socket, addr) = accepted.unwrap();
|
||||
let mut size_buf = [0u8; 2];
|
||||
let Ok(2) = socket.read(&mut size_buf).await else {
|
||||
eprintln!("[TCP] Failed to read message size from {}", addr);
|
||||
continue;
|
||||
};
|
||||
let message_size = u16::from_be_bytes(size_buf) as usize;
|
||||
let mut message_buffer = Vec::new();
|
||||
message_buffer.resize(message_size + 2, 0);
|
||||
match socket.read(&mut message_buffer[2..]).await {
|
||||
Ok(n) => {
|
||||
if n != message_size {
|
||||
eprintln!(
|
||||
"[TCP] Received too few bytes than expected from {}. Message size indicated {}, read {}",
|
||||
addr, message_size, n
|
||||
);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!(
|
||||
"[TCP] Failed to read message from {} with error: {:?}",
|
||||
addr, e
|
||||
);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
message_buffer[0] = size_buf[0];
|
||||
message_buffer[1] = size_buf[1];
|
||||
let mut reply = Vec::new();
|
||||
if let Err(e) = handle_dns_response(&message_buffer, &mut reply, true).await {
|
||||
eprintln!(
|
||||
"[TCP] Received error when handling response for {}: {:?}",
|
||||
addr, e
|
||||
);
|
||||
continue;
|
||||
}
|
||||
if let Err(e) = socket.write(&reply).await {
|
||||
eprintln!(
|
||||
"[TCP] Received error when sending response to {}: {:?}",
|
||||
addr, e
|
||||
);
|
||||
continue;
|
||||
}
|
||||
Ok(Self {
|
||||
config: Arc::clone(config),
|
||||
original_message: recv_buf,
|
||||
reply_address,
|
||||
reply_port,
|
||||
reply_socket: Arc::clone(receiving_socket),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
async fn udp_handler(socket: UdpSocket) {
|
||||
loop {
|
||||
let mut message_buffer = [0u8; 512];
|
||||
match socket.recv_from(&mut message_buffer).await {
|
||||
Ok((_, addr)) => {
|
||||
let mut reply = Vec::new();
|
||||
if let Err(e) = handle_dns_response(&message_buffer, &mut reply, false).await {
|
||||
eprintln!(
|
||||
"[UDP] Received error when handling response for {}: {:?}",
|
||||
addr, e
|
||||
);
|
||||
continue;
|
||||
}
|
||||
if let Err(e) = socket.send_to(&reply, addr).await {
|
||||
eprintln!(
|
||||
"[UDP] Received error when sending response to {}: {:?}",
|
||||
addr, e
|
||||
);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("[UDP] Failed to read message with error: {:?}", e);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
impl SelectiveRoutingHandler for UDPSelectiveRoutingHandler {
|
||||
async fn query_upstream(&self) -> Result<Vec<u8>, ()> {
|
||||
let query_upstream_socket = UdpSocket::bind("0.0.0.0:0").await.or(Err(()))?;
|
||||
|
||||
query_upstream_socket
|
||||
.send_to(&self.original_message, self.config.upstream_resolver)
|
||||
.await
|
||||
.or(Err(()))?;
|
||||
|
||||
let mut recv_buf = vec![0; 512];
|
||||
|
||||
query_upstream_socket
|
||||
.recv(&mut recv_buf)
|
||||
.await
|
||||
.or(Err(()))?;
|
||||
|
||||
Ok(recv_buf)
|
||||
}
|
||||
|
||||
async fn reply(&mut self, message: Vec<u8>) -> Result<(), ()> {
|
||||
let _ = self
|
||||
.reply_socket
|
||||
.send_to(&message, (self.reply_address, self.reply_port))
|
||||
.await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_original_message(&self) -> &[u8] {
|
||||
&self.original_message
|
||||
}
|
||||
}
|
||||
|
||||
struct TCPSelectiveRoutingHandler {
|
||||
config: Arc<Config>,
|
||||
original_message: Vec<u8>,
|
||||
reply_stream: TcpStream,
|
||||
}
|
||||
|
||||
impl TCPSelectiveRoutingHandler {
|
||||
async fn from(config: &Arc<Config>, mut receiving_stream: TcpStream) -> Result<Self, ()> {
|
||||
let mut message_size_buf: [u8; 2] = [0, 0];
|
||||
receiving_stream
|
||||
.read(&mut message_size_buf)
|
||||
.await
|
||||
.or(Err(()))?;
|
||||
|
||||
let message_size: u16 = u16::from_be_bytes(message_size_buf);
|
||||
|
||||
let mut message = vec![0u8; message_size as usize];
|
||||
|
||||
receiving_stream.read(&mut message).await.or(Err(()))?;
|
||||
|
||||
Ok(Self {
|
||||
config: Arc::clone(config),
|
||||
original_message: message,
|
||||
reply_stream: receiving_stream,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl SelectiveRoutingHandler for TCPSelectiveRoutingHandler {
|
||||
async fn query_upstream(&self) -> Result<Vec<u8>, ()> {
|
||||
let upstream_socket = TcpSocket::new_v4().or(Err(()))?;
|
||||
let mut upstream_stream = upstream_socket
|
||||
.connect(self.config.upstream_resolver.into())
|
||||
.await
|
||||
.or(Err(()))?;
|
||||
|
||||
upstream_stream
|
||||
.write(&(self.original_message.len() as u16).to_be_bytes())
|
||||
.await
|
||||
.or(Err(()))?;
|
||||
upstream_stream
|
||||
.write(&self.original_message)
|
||||
.await
|
||||
.or(Err(()))?;
|
||||
|
||||
let mut message_size_buf: [u8; 2] = [0, 0];
|
||||
upstream_stream
|
||||
.read(&mut message_size_buf)
|
||||
.await
|
||||
.or(Err(()))?;
|
||||
|
||||
let message_size: u16 = u16::from_be_bytes(message_size_buf);
|
||||
|
||||
let mut message = vec![0; message_size as usize];
|
||||
|
||||
upstream_stream.read(&mut message).await.or(Err(()))?;
|
||||
|
||||
Ok(message)
|
||||
}
|
||||
|
||||
async fn reply(&mut self, message: Vec<u8>) -> Result<(), ()> {
|
||||
self.reply_stream
|
||||
.write(&(message.len() as u16).to_be_bytes())
|
||||
.await
|
||||
.or(Err(()))?;
|
||||
self.reply_stream.write(&message).await.or(Err(()))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_original_message(&self) -> &[u8] {
|
||||
&self.original_message
|
||||
}
|
||||
}
|
||||
|
||||
@@ -607,11 +675,59 @@ async fn main() -> anyhow::Result<()> {
|
||||
}
|
||||
|
||||
let recv_socket_tcp = TcpListener::bind(LISTEN_ADDR.get().unwrap()).await?;
|
||||
let recv_socket_udp = UdpSocket::bind(LISTEN_ADDR.get().unwrap()).await?;
|
||||
let recv_socket_udp = Arc::new(UdpSocket::bind(LISTEN_ADDR.get().unwrap()).await?);
|
||||
|
||||
let config = Arc::new(Config {
|
||||
listen_addr: LISTEN_ADDR.get().unwrap().parse::<SocketAddrV4>().unwrap(),
|
||||
upstream_resolver: UPSTREAM_DNS.get().unwrap().parse::<SocketAddrV4>().unwrap(),
|
||||
});
|
||||
|
||||
tokio::try_join!(
|
||||
tokio::spawn(tcp_handler(recv_socket_tcp)),
|
||||
tokio::spawn(udp_handler(recv_socket_udp)),
|
||||
tokio::spawn({
|
||||
let config = Arc::clone(&config);
|
||||
let recv_socket_udp = Arc::clone(&recv_socket_udp);
|
||||
async move {
|
||||
loop {
|
||||
match UDPSelectiveRoutingHandler::from(&config, &recv_socket_udp).await {
|
||||
Ok(handler) => {
|
||||
tokio::spawn(async {
|
||||
if let Err(e) = handle_dns_response(handler).await {
|
||||
eprintln!("{}", e)
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Err(_) => (),
|
||||
}
|
||||
}
|
||||
}
|
||||
}),
|
||||
tokio::spawn({
|
||||
let config = Arc::clone(&config);
|
||||
async move {
|
||||
loop {
|
||||
match recv_socket_tcp.accept().await {
|
||||
Ok((stream, _)) => {
|
||||
eprintln!(
|
||||
"[TCP] Accepted connection from {}",
|
||||
stream.peer_addr().unwrap()
|
||||
);
|
||||
if let Ok(handler) =
|
||||
TCPSelectiveRoutingHandler::from(&config, stream).await
|
||||
{
|
||||
tokio::spawn(async {
|
||||
if let Err(e) = handle_dns_response(handler).await {
|
||||
eprintln!("{}", e)
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Err(_) => (),
|
||||
}
|
||||
}
|
||||
}
|
||||
}),
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user