Squash commits.
This commit is contained in:
@@ -0,0 +1,60 @@
|
||||
#[derive(Clone)]
|
||||
pub struct Cursor<'a, T> {
|
||||
buf: &'a [T],
|
||||
index: usize,
|
||||
}
|
||||
|
||||
impl<'a, T> Cursor<'a, T> {
|
||||
pub fn from(buf: &'a [T]) -> Self {
|
||||
Self { buf, index: 0 }
|
||||
}
|
||||
|
||||
pub fn index(&mut self) -> usize {
|
||||
self.index
|
||||
}
|
||||
|
||||
pub fn next(&mut self) -> Option<&T> {
|
||||
let next_index = self.index + 1;
|
||||
if next_index >= self.buf.len() {
|
||||
None
|
||||
} else {
|
||||
let v = &self.buf[self.index];
|
||||
self.index = next_index;
|
||||
Some(v)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn seek(&mut self, location: usize) -> Result<(), ()> {
|
||||
if location >= self.buf.len() {
|
||||
Err(())
|
||||
} else {
|
||||
self.index = location;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn next_slice(&mut self, amount: usize) -> Option<&'a [T]> {
|
||||
let next_index = self.index + amount;
|
||||
if next_index >= self.buf.len() {
|
||||
None
|
||||
} else {
|
||||
let slice = &self.buf[self.index..next_index];
|
||||
self.index = next_index;
|
||||
Some(slice)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn next_array<const N: usize>(&mut self) -> Option<[T; N]> where [T; N]: TryFrom<&'a [T]> {
|
||||
Some(self.next_slice(N)?.try_into().ok()?)
|
||||
}
|
||||
|
||||
pub fn forward(&mut self, amount: usize) -> Result<(), ()> {
|
||||
let next_index = self.index + amount;
|
||||
if next_index >= self.buf.len() {
|
||||
Err(())
|
||||
} else {
|
||||
self.index = next_index;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,131 @@
|
||||
use crate::cursor::Cursor;
|
||||
|
||||
pub fn dns_name_to_parts(cursor: &mut Cursor<u8>) -> Option<Vec<Vec<u8>>> {
|
||||
let mut parts: Vec<Vec<u8>> = Vec::new();
|
||||
let mut ptr_depth = 0;
|
||||
let mut restore_point = None;
|
||||
while let Some(&(mut byte)) = cursor.next() {
|
||||
if byte == 0 {
|
||||
break;
|
||||
}
|
||||
if byte >= 192 {
|
||||
if let None = restore_point {
|
||||
restore_point = Some(cursor.index() + 1);
|
||||
}
|
||||
if ptr_depth >= 16 {
|
||||
return None;
|
||||
}
|
||||
let ptr_lsb = *cursor.next()?;
|
||||
cursor
|
||||
.seek(u16::from_be_bytes([byte & 0b0011_1111, ptr_lsb]) as usize)
|
||||
.ok()?;
|
||||
byte = *cursor.next()?;
|
||||
ptr_depth += 1;
|
||||
}
|
||||
parts.push(cursor.next_slice(byte as usize)?.to_vec());
|
||||
}
|
||||
if let Some(position) = restore_point {
|
||||
cursor.seek(position).ok()?;
|
||||
}
|
||||
Some(parts)
|
||||
}
|
||||
|
||||
pub fn dns_name_len(buffer: &mut Cursor<u8>) -> Option<usize> {
|
||||
let mut length = 0;
|
||||
while let Some(&byte) = buffer.next() {
|
||||
if byte == 0 {
|
||||
return Some(length + 1);
|
||||
}
|
||||
if byte >= 192 {
|
||||
return Some(length + 2);
|
||||
}
|
||||
buffer.forward(byte as usize).ok()?;
|
||||
length += byte as usize + 1;
|
||||
}
|
||||
Some(length)
|
||||
}
|
||||
|
||||
pub fn parts_to_dns_name(parts: &Vec<Vec<u8>>) -> Vec<u8> {
|
||||
let mut result = Vec::new();
|
||||
for part in parts {
|
||||
result.push(part.len() as u8);
|
||||
result.extend_from_slice(&part);
|
||||
}
|
||||
result.push(0u8);
|
||||
result
|
||||
}
|
||||
|
||||
pub fn dns_parts_to_string(parts: &Vec<Vec<u8>>) -> String {
|
||||
let mut result = String::new();
|
||||
for part in parts {
|
||||
result.push_str(&String::from_utf8_lossy(&part));
|
||||
result.push('.');
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
pub fn string_to_dns_name(string: String) -> Vec<u8> {
|
||||
let mut result = Vec::new();
|
||||
|
||||
for part in string.split('.') {
|
||||
result.push(part.len() as u8);
|
||||
result.extend_from_slice(part.as_bytes());
|
||||
}
|
||||
result.push(0u8);
|
||||
return result;
|
||||
}
|
||||
|
||||
pub struct AnswerIterator<'a> {
|
||||
cursor: Cursor<'a, u8>,
|
||||
ancount: u16,
|
||||
}
|
||||
|
||||
impl<'a> AnswerIterator<'a> {
|
||||
pub fn from(buf: &'a [u8]) -> Option<Self> {
|
||||
let mut cursor = Cursor::from(buf);
|
||||
cursor.seek(4).ok()?;
|
||||
let qdcount = u16::from_be_bytes(cursor.next_array::<2>()?);
|
||||
let ancount = u16::from_be_bytes(cursor.next_array::<2>()?);
|
||||
cursor.forward(4).ok()?;
|
||||
|
||||
// Skip past the question section
|
||||
for _ in 0..qdcount {
|
||||
dns_name_len(&mut cursor)?;
|
||||
cursor.forward(4).ok()?;
|
||||
}
|
||||
Some(Self { cursor, ancount })
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Debug)]
|
||||
pub struct Answer {
|
||||
pub name: Vec<Vec<u8>>,
|
||||
pub rrtype: u16,
|
||||
pub rrclass: u16,
|
||||
pub ttl: u32,
|
||||
pub rdata: Vec<u8>,
|
||||
}
|
||||
|
||||
impl<'a> Iterator for AnswerIterator<'a> {
|
||||
type Item = Answer;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.ancount == 0 {
|
||||
return None;
|
||||
}
|
||||
self.ancount -= 1;
|
||||
let name = dns_name_to_parts(&mut self.cursor)?;
|
||||
let rrtype = u16::from_be_bytes(self.cursor.next_array::<2>()?);
|
||||
let rrclass = u16::from_be_bytes(self.cursor.next_array::<2>()?);
|
||||
let ttl = u32::from_be_bytes(self.cursor.next_array::<4>()?);
|
||||
let rdlength = u16::from_be_bytes(self.cursor.next_array::<2>()?) as usize;
|
||||
let rdata = self.cursor.next_slice(rdlength)?.to_vec();
|
||||
Some(Answer {
|
||||
name,
|
||||
rrtype,
|
||||
rrclass,
|
||||
ttl,
|
||||
rdata,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,47 @@
|
||||
use std::collections::{HashSet, VecDeque};
|
||||
use std::net::Ipv4Addr;
|
||||
|
||||
pub struct IpPool {
|
||||
free_ips: VecDeque<Ipv4Addr>,
|
||||
allocated_ips: HashSet<Ipv4Addr>,
|
||||
}
|
||||
|
||||
impl IpPool {
|
||||
pub fn new(base_addr: Ipv4Addr, subnet_prefix_size: u8) -> Option<Self> {
|
||||
if subnet_prefix_size == 32 {
|
||||
return None;
|
||||
}
|
||||
let subnet_prefix_mask: u32 = u32::MAX << (32 - subnet_prefix_size);
|
||||
let base_addr_int = <Ipv4Addr as Into<u32>>::into(base_addr);
|
||||
if base_addr_int & !subnet_prefix_mask != 0 {
|
||||
return None;
|
||||
}
|
||||
let last_address_number = u32::MAX & !subnet_prefix_mask;
|
||||
let mut pool = VecDeque::with_capacity(last_address_number as usize);
|
||||
for number in 0..=(last_address_number as usize) {
|
||||
pool.push_back(Ipv4Addr::from(base_addr_int + (number as u32)));
|
||||
}
|
||||
|
||||
Some(Self {
|
||||
free_ips: pool,
|
||||
allocated_ips: HashSet::new(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn acquire(&mut self) -> Option<Ipv4Addr> {
|
||||
match self.free_ips.pop_front() {
|
||||
ip @ Some(addr) => {
|
||||
self.allocated_ips.insert(addr);
|
||||
ip
|
||||
}
|
||||
None => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn release(&mut self, addr: Ipv4Addr) {
|
||||
if self.allocated_ips.contains(&addr) {
|
||||
self.allocated_ips.remove(&addr);
|
||||
self.free_ips.push_back(addr);
|
||||
}
|
||||
}
|
||||
}
|
||||
+551
@@ -0,0 +1,551 @@
|
||||
#[macro_use]
|
||||
extern crate lazy_static;
|
||||
|
||||
mod cursor;
|
||||
mod dns_parser;
|
||||
mod ip_pool;
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::net::Ipv4Addr;
|
||||
use std::process::{Command, Output};
|
||||
use std::sync::OnceLock;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use ini::Ini;
|
||||
use tokio::{
|
||||
io::{AsyncReadExt, AsyncWriteExt},
|
||||
net::{TcpListener, TcpStream, UdpSocket},
|
||||
sync::Mutex,
|
||||
};
|
||||
|
||||
use crate::cursor::Cursor;
|
||||
use crate::dns_parser::{
|
||||
Answer, AnswerIterator, dns_name_to_parts, dns_parts_to_string,
|
||||
parts_to_dns_name, string_to_dns_name,
|
||||
};
|
||||
use crate::ip_pool::IpPool;
|
||||
|
||||
fn run_command_string(command: String) -> Result<Output, std::io::Error> {
|
||||
let command_vec = {
|
||||
if COMMAND_PREFIX.get().unwrap().eq("") {
|
||||
command.split(' ').collect::<Vec<&str>>()
|
||||
} else {
|
||||
COMMAND_PREFIX
|
||||
.get()
|
||||
.unwrap()
|
||||
.split(' ')
|
||||
.chain(command.split(' '))
|
||||
.collect::<Vec<&str>>()
|
||||
}
|
||||
};
|
||||
|
||||
Command::new(command_vec[0])
|
||||
.args(&command_vec[1..])
|
||||
.output()
|
||||
}
|
||||
|
||||
async fn setup_forwarding(
|
||||
forwarding_map: &mut HashMap<Vec<u8>, Vec<Forwarding>>,
|
||||
dns_name: &Vec<u8>,
|
||||
ttl: u32,
|
||||
original_ip: Ipv4Addr,
|
||||
forged_ip: Ipv4Addr,
|
||||
) {
|
||||
let forwarding = Forwarding {
|
||||
expires: Instant::now() + Duration::from_secs(ttl as u64),
|
||||
forged_ip: forged_ip,
|
||||
original_ip: original_ip,
|
||||
ttl: ttl,
|
||||
};
|
||||
|
||||
match forwarding_map.get_mut(dns_name) {
|
||||
Some(forwarding_list) => {
|
||||
forwarding_list.push(forwarding);
|
||||
}
|
||||
None => {
|
||||
forwarding_map.insert(dns_name.clone(), vec![forwarding]);
|
||||
}
|
||||
}
|
||||
|
||||
let mark_command = format!(
|
||||
"iptables -t mangle -A PREROUTING -d {} -j MARK --set-mark {}",
|
||||
forged_ip,
|
||||
FwmarkConfigMap.lock().await.get(dns_name).unwrap()
|
||||
);
|
||||
let _mark_output = run_command_string(mark_command).unwrap();
|
||||
|
||||
let dnat_command = format!(
|
||||
"iptables -t nat -A PREROUTING -d {} -j DNAT --to {}",
|
||||
forged_ip, original_ip
|
||||
);
|
||||
let _dnat_output = run_command_string(dnat_command).unwrap();
|
||||
}
|
||||
|
||||
async fn teardown_forwarding(
|
||||
forwarding_map: &mut HashMap<Vec<u8>, Vec<Forwarding>>,
|
||||
dns_name: &Vec<u8>,
|
||||
original_ip: Ipv4Addr,
|
||||
) {
|
||||
let forwarding_list = match forwarding_map.get_mut(dns_name) {
|
||||
Some(f) => f,
|
||||
None => return,
|
||||
};
|
||||
|
||||
let forwarding = match forwarding_list
|
||||
.iter()
|
||||
.find(|f| f.original_ip == original_ip)
|
||||
{
|
||||
Some(f) => f,
|
||||
None => return,
|
||||
}
|
||||
.clone();
|
||||
|
||||
forwarding_list.retain(|f| f.original_ip != original_ip);
|
||||
|
||||
if forwarding_list.len() == 0 {
|
||||
forwarding_map.remove(dns_name);
|
||||
}
|
||||
|
||||
IpAllocator.lock().await.release(forwarding.forged_ip);
|
||||
let mark_command = format!(
|
||||
"iptables -t mangle -D PREROUTING -d {} -j MARK --set-mark {}",
|
||||
forwarding.forged_ip,
|
||||
FwmarkConfigMap.lock().await.get(dns_name).unwrap()
|
||||
);
|
||||
let _mark_output = run_command_string(mark_command).unwrap();
|
||||
|
||||
let dnat_command = format!(
|
||||
"iptables -t nat -D PREROUTING -d {} -j DNAT --to {}",
|
||||
forwarding.forged_ip, forwarding.original_ip
|
||||
);
|
||||
let _dnat_output = run_command_string(dnat_command).unwrap();
|
||||
}
|
||||
|
||||
async fn query_upstream_resolvers(original_message: &[u8], tcp: bool) -> anyhow::Result<Vec<u8>> {
|
||||
let mut upstream_reply = Vec::new();
|
||||
|
||||
if tcp {
|
||||
let mut upstream = TcpStream::connect(UPSTREAM_DNS.get().unwrap()).await?;
|
||||
//upstream.write(&(original_message.len() as u16).to_be_bytes()).await?;
|
||||
upstream.write(&original_message).await?;
|
||||
let mut size_buffer = [0u8; 2];
|
||||
upstream.read(&mut size_buffer).await?;
|
||||
let size: u16 = u16::from_be_bytes(size_buffer);
|
||||
upstream_reply.resize(size as usize + 2, 0u8);
|
||||
upstream.read(&mut upstream_reply[2..]).await?;
|
||||
upstream_reply[0] = size_buffer[0];
|
||||
upstream_reply[1] = size_buffer[1];
|
||||
} else {
|
||||
let upstream = UdpSocket::bind("0.0.0.0:0").await?;
|
||||
upstream.connect(UPSTREAM_DNS.get().unwrap()).await?;
|
||||
upstream.send(original_message).await?;
|
||||
upstream_reply.resize(512, 0);
|
||||
upstream.recv_from(&mut upstream_reply[..512]).await?;
|
||||
}
|
||||
|
||||
return Ok(upstream_reply);
|
||||
}
|
||||
|
||||
fn forge_replies(replies: &Vec<Forwarding>, dns_name_string: String, qname_parts: Vec<Vec<u8>>, original_message: &[u8], reply_buf: &mut Vec<u8>, tcp: bool) {
|
||||
let reply: [u8; 12] = [
|
||||
0u8,
|
||||
0, // ID
|
||||
0b1000_0000,
|
||||
0b0000_0000, // Flags
|
||||
0,
|
||||
1, // Qdcount
|
||||
0,
|
||||
1, // Ancount
|
||||
0,
|
||||
0, // Nscount
|
||||
0,
|
||||
0, // Arcount
|
||||
];
|
||||
|
||||
let offset = if tcp { 2 } else { 0 };
|
||||
|
||||
let now = Instant::now();
|
||||
|
||||
println!(
|
||||
"Forging reply for {}, original IPs: {:?}, forged IPs: {:?}",
|
||||
dns_name_string,
|
||||
replies
|
||||
.iter()
|
||||
.map(|e| e.original_ip)
|
||||
.collect::<Vec<Ipv4Addr>>(),
|
||||
replies
|
||||
.iter()
|
||||
.map(|e| e.forged_ip)
|
||||
.collect::<Vec<Ipv4Addr>>()
|
||||
);
|
||||
|
||||
let mut new_reply = reply.clone().to_vec();
|
||||
new_reply[0..2].copy_from_slice(&original_message[offset..][0..=1]);
|
||||
new_reply[6..8].copy_from_slice(&(replies.len() as u16).to_be_bytes());
|
||||
new_reply.extend_from_slice(&original_message[offset..][12..=12 + qname_parts.iter().map(|p| p.len()+1).sum::<usize>() + 1 + 3]);
|
||||
for reply in replies {
|
||||
new_reply.extend_from_slice(&parts_to_dns_name(&qname_parts));
|
||||
new_reply.extend_from_slice(&[0, 1]);
|
||||
new_reply.extend_from_slice(&[0, 1]);
|
||||
new_reply.extend_from_slice(&((reply.expires - now).as_secs() as u32).to_be_bytes());
|
||||
new_reply.extend_from_slice(&[0, 4]);
|
||||
new_reply.extend_from_slice(&reply.forged_ip.octets());
|
||||
}
|
||||
if tcp {
|
||||
reply_buf.extend_from_slice(&(new_reply.len() as u16).to_be_bytes());
|
||||
}
|
||||
reply_buf.extend_from_slice(&new_reply);
|
||||
}
|
||||
|
||||
async fn handle_dns_response(buf: &[u8], reply_buf: &mut Vec<u8>, tcp: bool) -> anyhow::Result<()> {
|
||||
|
||||
let offset: usize = if tcp { 2 } else { 0 };
|
||||
let mut cursor = Cursor::from(&buf[offset..]);
|
||||
|
||||
// Identify some metadata from the query
|
||||
if cursor.seek(4).is_err() {
|
||||
return Err(anyhow::Error::msg("Failed to seek to QDCount"));
|
||||
}
|
||||
let qdcount = u16::from_be_bytes(cursor.next_array::<2>().ok_or(anyhow::Error::msg("Failed to read QDCount"))?);
|
||||
if qdcount != 1 {
|
||||
eprintln!("Got qdcount: {}", qdcount);
|
||||
return Err(anyhow::Error::msg(
|
||||
"Missing question from query. Got qdcount {}",
|
||||
));
|
||||
}
|
||||
if cursor.forward(6).is_err() {
|
||||
return Err(anyhow::Error::msg("Failed to seek to question section"));
|
||||
}
|
||||
let qname_parts =
|
||||
dns_name_to_parts(&mut cursor).ok_or(anyhow::Error::msg("Failed to decode QName."))?;
|
||||
let qtype = u16::from_be_bytes(cursor.next_array::<2>().ok_or(anyhow::Error::msg("Failed to read QType."))?);
|
||||
let qclass = u16::from_be_bytes(cursor.next_array::<2>().ok_or(anyhow::Error::msg("Failed to read QClass."))?);
|
||||
|
||||
// If the query is for anything other than qclass IN or qtype A, just forward the query upstream
|
||||
if qclass != 1 || qtype != 1 {
|
||||
let upstream_reply = query_upstream_resolvers(buf, tcp).await?;
|
||||
reply_buf.extend_from_slice(&upstream_reply);
|
||||
return Ok(())
|
||||
}
|
||||
|
||||
let now = Instant::now();
|
||||
|
||||
let dns_name = parts_to_dns_name(&qname_parts);
|
||||
let dns_name_string = dns_parts_to_string(&qname_parts);
|
||||
|
||||
let entries = match ForwardingMap.lock().await.get_mut(&dns_name) {
|
||||
Some(forwardings) => forwardings.clone(),
|
||||
None => Vec::new(),
|
||||
};
|
||||
|
||||
// Let's first lookup the qname in the Forwardings to see if we have non-expired answers
|
||||
if entries.len() > 0 && entries.iter().all(|e| e.expires > now) {
|
||||
forge_replies(&entries, dns_name_string, qname_parts, buf, reply_buf, tcp);
|
||||
return Ok(());
|
||||
} else {
|
||||
let upstream_reply = query_upstream_resolvers(buf, tcp).await?;
|
||||
|
||||
// Try an answer from the upstream response that has type A.
|
||||
let a_answers = match AnswerIterator::from(&upstream_reply[offset..]) {
|
||||
Some(answers) => answers
|
||||
.filter(|Answer { rrtype, .. }| *rrtype == 1)
|
||||
.collect::<Vec<Answer>>(),
|
||||
None => {
|
||||
return Err(anyhow::Error::msg(
|
||||
"Failed to extract answers from upstream reply!",
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let mut forge = true;
|
||||
if qtype != 1 {
|
||||
// Only forge for queries with an A qtype
|
||||
eprintln!("Not forging due non-A type question.");
|
||||
forge = false;
|
||||
}
|
||||
|
||||
if a_answers.len() == 0 {
|
||||
// If no A type answer, don't forge
|
||||
eprintln!("Not forging due to no returned A type answers.");
|
||||
forge = false;
|
||||
}
|
||||
|
||||
if a_answers.iter().any(|a| a.rdata.len() != 4) {
|
||||
eprintln!("Not forging due to malformed A type answer.",);
|
||||
forge = false;
|
||||
}
|
||||
|
||||
if !FwmarkConfigMap
|
||||
.lock()
|
||||
.await
|
||||
.contains_key(&parts_to_dns_name(&qname_parts))
|
||||
{
|
||||
eprintln!("Not forging due to non-matching qname.");
|
||||
forge = false;
|
||||
}
|
||||
|
||||
if forge {
|
||||
// Normalise a_answers so we're working with Ipv4Addr
|
||||
let normalised_answers: Vec<(Vec<Vec<u8>>, u32, Ipv4Addr)> = a_answers
|
||||
.iter()
|
||||
.map(|answer| {
|
||||
(
|
||||
answer.name.clone(),
|
||||
answer.ttl,
|
||||
Ipv4Addr::from(
|
||||
<Vec<u8> as TryInto<[u8; 4]>>::try_into(answer.rdata.clone()).unwrap(),
|
||||
),
|
||||
)
|
||||
})
|
||||
.collect::<_>();
|
||||
|
||||
let mut replies: Vec<Forwarding> = Vec::new();
|
||||
|
||||
// Determine if we need to create or renew our entries
|
||||
if entries.len() > 0 && entries.iter().all(|e| e.expires > now) {
|
||||
println!("Found not expired forwardings for {}", dns_name_string);
|
||||
replies.extend(entries.clone() as Vec<Forwarding>);
|
||||
} else {
|
||||
// We want to identify which of our A answers already exist in the ForwardingMap
|
||||
let existing_entries = entries
|
||||
.iter()
|
||||
.filter(|e| {
|
||||
normalised_answers
|
||||
.iter()
|
||||
.any(|(_, _, addr)| e.original_ip == *addr)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// Let's also identify which entries don't match any of the replies
|
||||
let nonexisting_entries = entries
|
||||
.iter()
|
||||
.filter(|e| {
|
||||
!normalised_answers
|
||||
.iter()
|
||||
.any(|(_, _, addr)| e.original_ip == *addr)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// And now let's find the replies that don't match any of the current entries
|
||||
let new_answers = normalised_answers
|
||||
.iter()
|
||||
.filter(|answer| entries.iter().all(|e| e.original_ip != answer.2))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// Acquire the forwarding map
|
||||
let mut forwarding_map = ForwardingMap.lock().await;
|
||||
|
||||
// Remove all non-existing entries
|
||||
for entry in nonexisting_entries {
|
||||
println!(
|
||||
"Removed forwarding for {} with real IP {} / forged IP {} as the upstream no longer returns this value",
|
||||
dns_name_string, entry.original_ip, entry.forged_ip
|
||||
);
|
||||
teardown_forwarding(&mut forwarding_map, &dns_name, entry.original_ip).await;
|
||||
}
|
||||
|
||||
// Add new answers
|
||||
for (_, ttl, original_ip) in new_answers {
|
||||
let forged_ip = IpAllocator.lock().await.acquire().unwrap();
|
||||
setup_forwarding(
|
||||
&mut forwarding_map,
|
||||
&dns_name,
|
||||
*ttl,
|
||||
*original_ip,
|
||||
forged_ip,
|
||||
)
|
||||
.await;
|
||||
println!(
|
||||
"Added forwarding for {} with real IP {} / forged IP {}",
|
||||
dns_name_string, original_ip, forged_ip
|
||||
);
|
||||
}
|
||||
|
||||
// For all the existing entries, update their TTL if they're expired
|
||||
for entry in existing_entries.iter() {
|
||||
println!(
|
||||
"Updating TTL of existing entry {}/{}",
|
||||
dns_name_string, entry.forged_ip
|
||||
);
|
||||
if entry.expires < now {
|
||||
teardown_forwarding(&mut forwarding_map, &dns_name, entry.original_ip).await;
|
||||
setup_forwarding(
|
||||
&mut forwarding_map,
|
||||
&dns_name,
|
||||
entry.ttl,
|
||||
entry.original_ip,
|
||||
entry.forged_ip,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
let entries = forwarding_map.get(&dns_name).unwrap().clone();
|
||||
replies.extend(entries);
|
||||
}
|
||||
|
||||
forge_replies(&replies, dns_name_string, qname_parts, buf, reply_buf, tcp);
|
||||
} else {
|
||||
reply_buf.extend_from_slice(&upstream_reply);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct Forwarding {
|
||||
expires: Instant,
|
||||
forged_ip: Ipv4Addr,
|
||||
original_ip: Ipv4Addr,
|
||||
ttl: u32,
|
||||
}
|
||||
|
||||
lazy_static! {
|
||||
static ref FwmarkConfigMap: Mutex<HashMap<Vec<u8>, u32>> = Mutex::new(HashMap::new());
|
||||
static ref ForwardingMap: Mutex<HashMap<Vec<u8>, Vec<Forwarding>>> = Mutex::new(HashMap::new());
|
||||
static ref IpAllocator: Mutex<IpPool> =
|
||||
Mutex::new(IpPool::new(Ipv4Addr::new(100, 64, 0, 0), 24).unwrap());
|
||||
}
|
||||
|
||||
static COMMAND_PREFIX: OnceLock<String> = OnceLock::new();
|
||||
static UPSTREAM_DNS: OnceLock<String> = OnceLock::new();
|
||||
static LISTEN_ADDR: OnceLock<String> = OnceLock::new();
|
||||
|
||||
async fn tcp_handler(listener: TcpListener) {
|
||||
loop {
|
||||
let accepted = listener.accept().await;
|
||||
if let Err(e) = &accepted {
|
||||
eprintln!("[TCP] Failed to accept TCP socket with error: {:?}", e);
|
||||
}
|
||||
|
||||
let (mut socket, addr) = accepted.unwrap();
|
||||
let mut size_buf = [0u8; 2];
|
||||
let Ok(2) = socket.read(&mut size_buf).await else {
|
||||
eprintln!("[TCP] Failed to read message size from {}", addr);
|
||||
continue;
|
||||
};
|
||||
let message_size = u16::from_be_bytes(size_buf) as usize;
|
||||
let mut message_buffer = Vec::new();
|
||||
message_buffer.resize(message_size + 2, 0);
|
||||
match socket.read(&mut message_buffer[2..]).await {
|
||||
Ok(n) => {
|
||||
if n != message_size {
|
||||
eprintln!(
|
||||
"[TCP] Received too few bytes than expected from {}. Message size indicated {}, read {}",
|
||||
addr, message_size, n
|
||||
);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!(
|
||||
"[TCP] Failed to read message from {} with error: {:?}",
|
||||
addr, e
|
||||
);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
message_buffer[0] = size_buf[0];
|
||||
message_buffer[1] = size_buf[1];
|
||||
let mut reply = Vec::new();
|
||||
if let Err(e) = handle_dns_response(&message_buffer, &mut reply, true).await {
|
||||
eprintln!(
|
||||
"[TCP] Received error when handling response for {}: {:?}",
|
||||
addr, e
|
||||
);
|
||||
continue;
|
||||
}
|
||||
if let Err(e) = socket.write(&reply).await {
|
||||
eprintln!(
|
||||
"[TCP] Received error when sending response to {}: {:?}",
|
||||
addr, e
|
||||
);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn udp_handler(socket: UdpSocket) {
|
||||
loop {
|
||||
let mut message_buffer = [0u8; 512];
|
||||
match socket.recv_from(&mut message_buffer).await {
|
||||
Ok((_, addr)) => {
|
||||
let mut reply = Vec::new();
|
||||
if let Err(e) = handle_dns_response(&message_buffer, &mut reply, false).await {
|
||||
eprintln!(
|
||||
"[UDP] Received error when handling response for {}: {:?}",
|
||||
addr, e
|
||||
);
|
||||
continue;
|
||||
}
|
||||
if let Err(e) = socket.send_to(&reply, addr).await {
|
||||
eprintln!(
|
||||
"[UDP] Received error when sending response to {}: {:?}",
|
||||
addr, e
|
||||
);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("[UDP] Failed to read message with error: {:?}", e);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let mut args = std::env::args();
|
||||
let _ = args.next();
|
||||
let config_file_path = match args.next() {
|
||||
Some(p) => p,
|
||||
None => String::from("/etc/rust-dns-selective-routing/config.ini"),
|
||||
};
|
||||
let i = Ini::load_from_file(config_file_path).unwrap();
|
||||
|
||||
let map: HashMap<String, HashMap<String, String>> = i
|
||||
.into_iter()
|
||||
.filter_map(|(k, it)| match k {
|
||||
Some(s) => Some((s, it.into_iter().collect())),
|
||||
None => None,
|
||||
})
|
||||
.collect();
|
||||
|
||||
if let Some(v) = map.get("config") {
|
||||
if let Some(cmd) = v.get("command_prefix") {
|
||||
COMMAND_PREFIX.set(cmd.clone()).unwrap();
|
||||
} else {
|
||||
COMMAND_PREFIX.set(String::from("")).unwrap();
|
||||
}
|
||||
if let Some(upstream) = v.get("upstream_dns") {
|
||||
UPSTREAM_DNS.set(upstream.clone()).unwrap();
|
||||
} else {
|
||||
UPSTREAM_DNS.set(String::from("1.1.1.1:53")).unwrap();
|
||||
}
|
||||
if let Some(listen_addr) = v.get("listen_addr") {
|
||||
LISTEN_ADDR.set(listen_addr.clone()).unwrap();
|
||||
} else {
|
||||
LISTEN_ADDR.set(String::from("127.0.0.1:53")).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(v) = map.get("domains") {
|
||||
let mut fwmark_config = FwmarkConfigMap.lock().await;
|
||||
for (domain, fwmark) in v.iter() {
|
||||
if let Ok(fwmark) = fwmark.parse::<u32>() {
|
||||
let _ = fwmark_config.insert(string_to_dns_name(domain.to_string()), fwmark);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let recv_socket_tcp = TcpListener::bind(LISTEN_ADDR.get().unwrap()).await?;
|
||||
let recv_socket_udp = UdpSocket::bind(LISTEN_ADDR.get().unwrap()).await?;
|
||||
|
||||
tokio::try_join!(
|
||||
tokio::spawn(tcp_handler(recv_socket_tcp)),
|
||||
tokio::spawn(udp_handler(recv_socket_udp)),
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
Reference in New Issue
Block a user