Squash commits.

This commit is contained in:
2025-12-05 22:03:11 +00:00
commit 79fee04aef
7 changed files with 1261 additions and 0 deletions
+60
View File
@@ -0,0 +1,60 @@
#[derive(Clone)]
pub struct Cursor<'a, T> {
buf: &'a [T],
index: usize,
}
impl<'a, T> Cursor<'a, T> {
pub fn from(buf: &'a [T]) -> Self {
Self { buf, index: 0 }
}
pub fn index(&mut self) -> usize {
self.index
}
pub fn next(&mut self) -> Option<&T> {
let next_index = self.index + 1;
if next_index >= self.buf.len() {
None
} else {
let v = &self.buf[self.index];
self.index = next_index;
Some(v)
}
}
pub fn seek(&mut self, location: usize) -> Result<(), ()> {
if location >= self.buf.len() {
Err(())
} else {
self.index = location;
Ok(())
}
}
pub fn next_slice(&mut self, amount: usize) -> Option<&'a [T]> {
let next_index = self.index + amount;
if next_index >= self.buf.len() {
None
} else {
let slice = &self.buf[self.index..next_index];
self.index = next_index;
Some(slice)
}
}
pub fn next_array<const N: usize>(&mut self) -> Option<[T; N]> where [T; N]: TryFrom<&'a [T]> {
Some(self.next_slice(N)?.try_into().ok()?)
}
pub fn forward(&mut self, amount: usize) -> Result<(), ()> {
let next_index = self.index + amount;
if next_index >= self.buf.len() {
Err(())
} else {
self.index = next_index;
Ok(())
}
}
}
+131
View File
@@ -0,0 +1,131 @@
use crate::cursor::Cursor;
pub fn dns_name_to_parts(cursor: &mut Cursor<u8>) -> Option<Vec<Vec<u8>>> {
let mut parts: Vec<Vec<u8>> = Vec::new();
let mut ptr_depth = 0;
let mut restore_point = None;
while let Some(&(mut byte)) = cursor.next() {
if byte == 0 {
break;
}
if byte >= 192 {
if let None = restore_point {
restore_point = Some(cursor.index() + 1);
}
if ptr_depth >= 16 {
return None;
}
let ptr_lsb = *cursor.next()?;
cursor
.seek(u16::from_be_bytes([byte & 0b0011_1111, ptr_lsb]) as usize)
.ok()?;
byte = *cursor.next()?;
ptr_depth += 1;
}
parts.push(cursor.next_slice(byte as usize)?.to_vec());
}
if let Some(position) = restore_point {
cursor.seek(position).ok()?;
}
Some(parts)
}
pub fn dns_name_len(buffer: &mut Cursor<u8>) -> Option<usize> {
let mut length = 0;
while let Some(&byte) = buffer.next() {
if byte == 0 {
return Some(length + 1);
}
if byte >= 192 {
return Some(length + 2);
}
buffer.forward(byte as usize).ok()?;
length += byte as usize + 1;
}
Some(length)
}
pub fn parts_to_dns_name(parts: &Vec<Vec<u8>>) -> Vec<u8> {
let mut result = Vec::new();
for part in parts {
result.push(part.len() as u8);
result.extend_from_slice(&part);
}
result.push(0u8);
result
}
pub fn dns_parts_to_string(parts: &Vec<Vec<u8>>) -> String {
let mut result = String::new();
for part in parts {
result.push_str(&String::from_utf8_lossy(&part));
result.push('.');
}
return result;
}
pub fn string_to_dns_name(string: String) -> Vec<u8> {
let mut result = Vec::new();
for part in string.split('.') {
result.push(part.len() as u8);
result.extend_from_slice(part.as_bytes());
}
result.push(0u8);
return result;
}
pub struct AnswerIterator<'a> {
cursor: Cursor<'a, u8>,
ancount: u16,
}
impl<'a> AnswerIterator<'a> {
pub fn from(buf: &'a [u8]) -> Option<Self> {
let mut cursor = Cursor::from(buf);
cursor.seek(4).ok()?;
let qdcount = u16::from_be_bytes(cursor.next_array::<2>()?);
let ancount = u16::from_be_bytes(cursor.next_array::<2>()?);
cursor.forward(4).ok()?;
// Skip past the question section
for _ in 0..qdcount {
dns_name_len(&mut cursor)?;
cursor.forward(4).ok()?;
}
Some(Self { cursor, ancount })
}
}
#[derive(PartialEq, Debug)]
pub struct Answer {
pub name: Vec<Vec<u8>>,
pub rrtype: u16,
pub rrclass: u16,
pub ttl: u32,
pub rdata: Vec<u8>,
}
impl<'a> Iterator for AnswerIterator<'a> {
type Item = Answer;
fn next(&mut self) -> Option<Self::Item> {
if self.ancount == 0 {
return None;
}
self.ancount -= 1;
let name = dns_name_to_parts(&mut self.cursor)?;
let rrtype = u16::from_be_bytes(self.cursor.next_array::<2>()?);
let rrclass = u16::from_be_bytes(self.cursor.next_array::<2>()?);
let ttl = u32::from_be_bytes(self.cursor.next_array::<4>()?);
let rdlength = u16::from_be_bytes(self.cursor.next_array::<2>()?) as usize;
let rdata = self.cursor.next_slice(rdlength)?.to_vec();
Some(Answer {
name,
rrtype,
rrclass,
ttl,
rdata,
})
}
}
+47
View File
@@ -0,0 +1,47 @@
use std::collections::{HashSet, VecDeque};
use std::net::Ipv4Addr;
pub struct IpPool {
free_ips: VecDeque<Ipv4Addr>,
allocated_ips: HashSet<Ipv4Addr>,
}
impl IpPool {
pub fn new(base_addr: Ipv4Addr, subnet_prefix_size: u8) -> Option<Self> {
if subnet_prefix_size == 32 {
return None;
}
let subnet_prefix_mask: u32 = u32::MAX << (32 - subnet_prefix_size);
let base_addr_int = <Ipv4Addr as Into<u32>>::into(base_addr);
if base_addr_int & !subnet_prefix_mask != 0 {
return None;
}
let last_address_number = u32::MAX & !subnet_prefix_mask;
let mut pool = VecDeque::with_capacity(last_address_number as usize);
for number in 0..=(last_address_number as usize) {
pool.push_back(Ipv4Addr::from(base_addr_int + (number as u32)));
}
Some(Self {
free_ips: pool,
allocated_ips: HashSet::new(),
})
}
pub fn acquire(&mut self) -> Option<Ipv4Addr> {
match self.free_ips.pop_front() {
ip @ Some(addr) => {
self.allocated_ips.insert(addr);
ip
}
None => None,
}
}
pub fn release(&mut self, addr: Ipv4Addr) {
if self.allocated_ips.contains(&addr) {
self.allocated_ips.remove(&addr);
self.free_ips.push_back(addr);
}
}
}
+551
View File
@@ -0,0 +1,551 @@
#[macro_use]
extern crate lazy_static;
mod cursor;
mod dns_parser;
mod ip_pool;
use std::collections::HashMap;
use std::net::Ipv4Addr;
use std::process::{Command, Output};
use std::sync::OnceLock;
use std::time::{Duration, Instant};
use ini::Ini;
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::{TcpListener, TcpStream, UdpSocket},
sync::Mutex,
};
use crate::cursor::Cursor;
use crate::dns_parser::{
Answer, AnswerIterator, dns_name_to_parts, dns_parts_to_string,
parts_to_dns_name, string_to_dns_name,
};
use crate::ip_pool::IpPool;
fn run_command_string(command: String) -> Result<Output, std::io::Error> {
let command_vec = {
if COMMAND_PREFIX.get().unwrap().eq("") {
command.split(' ').collect::<Vec<&str>>()
} else {
COMMAND_PREFIX
.get()
.unwrap()
.split(' ')
.chain(command.split(' '))
.collect::<Vec<&str>>()
}
};
Command::new(command_vec[0])
.args(&command_vec[1..])
.output()
}
async fn setup_forwarding(
forwarding_map: &mut HashMap<Vec<u8>, Vec<Forwarding>>,
dns_name: &Vec<u8>,
ttl: u32,
original_ip: Ipv4Addr,
forged_ip: Ipv4Addr,
) {
let forwarding = Forwarding {
expires: Instant::now() + Duration::from_secs(ttl as u64),
forged_ip: forged_ip,
original_ip: original_ip,
ttl: ttl,
};
match forwarding_map.get_mut(dns_name) {
Some(forwarding_list) => {
forwarding_list.push(forwarding);
}
None => {
forwarding_map.insert(dns_name.clone(), vec![forwarding]);
}
}
let mark_command = format!(
"iptables -t mangle -A PREROUTING -d {} -j MARK --set-mark {}",
forged_ip,
FwmarkConfigMap.lock().await.get(dns_name).unwrap()
);
let _mark_output = run_command_string(mark_command).unwrap();
let dnat_command = format!(
"iptables -t nat -A PREROUTING -d {} -j DNAT --to {}",
forged_ip, original_ip
);
let _dnat_output = run_command_string(dnat_command).unwrap();
}
async fn teardown_forwarding(
forwarding_map: &mut HashMap<Vec<u8>, Vec<Forwarding>>,
dns_name: &Vec<u8>,
original_ip: Ipv4Addr,
) {
let forwarding_list = match forwarding_map.get_mut(dns_name) {
Some(f) => f,
None => return,
};
let forwarding = match forwarding_list
.iter()
.find(|f| f.original_ip == original_ip)
{
Some(f) => f,
None => return,
}
.clone();
forwarding_list.retain(|f| f.original_ip != original_ip);
if forwarding_list.len() == 0 {
forwarding_map.remove(dns_name);
}
IpAllocator.lock().await.release(forwarding.forged_ip);
let mark_command = format!(
"iptables -t mangle -D PREROUTING -d {} -j MARK --set-mark {}",
forwarding.forged_ip,
FwmarkConfigMap.lock().await.get(dns_name).unwrap()
);
let _mark_output = run_command_string(mark_command).unwrap();
let dnat_command = format!(
"iptables -t nat -D PREROUTING -d {} -j DNAT --to {}",
forwarding.forged_ip, forwarding.original_ip
);
let _dnat_output = run_command_string(dnat_command).unwrap();
}
async fn query_upstream_resolvers(original_message: &[u8], tcp: bool) -> anyhow::Result<Vec<u8>> {
let mut upstream_reply = Vec::new();
if tcp {
let mut upstream = TcpStream::connect(UPSTREAM_DNS.get().unwrap()).await?;
//upstream.write(&(original_message.len() as u16).to_be_bytes()).await?;
upstream.write(&original_message).await?;
let mut size_buffer = [0u8; 2];
upstream.read(&mut size_buffer).await?;
let size: u16 = u16::from_be_bytes(size_buffer);
upstream_reply.resize(size as usize + 2, 0u8);
upstream.read(&mut upstream_reply[2..]).await?;
upstream_reply[0] = size_buffer[0];
upstream_reply[1] = size_buffer[1];
} else {
let upstream = UdpSocket::bind("0.0.0.0:0").await?;
upstream.connect(UPSTREAM_DNS.get().unwrap()).await?;
upstream.send(original_message).await?;
upstream_reply.resize(512, 0);
upstream.recv_from(&mut upstream_reply[..512]).await?;
}
return Ok(upstream_reply);
}
fn forge_replies(replies: &Vec<Forwarding>, dns_name_string: String, qname_parts: Vec<Vec<u8>>, original_message: &[u8], reply_buf: &mut Vec<u8>, tcp: bool) {
let reply: [u8; 12] = [
0u8,
0, // ID
0b1000_0000,
0b0000_0000, // Flags
0,
1, // Qdcount
0,
1, // Ancount
0,
0, // Nscount
0,
0, // Arcount
];
let offset = if tcp { 2 } else { 0 };
let now = Instant::now();
println!(
"Forging reply for {}, original IPs: {:?}, forged IPs: {:?}",
dns_name_string,
replies
.iter()
.map(|e| e.original_ip)
.collect::<Vec<Ipv4Addr>>(),
replies
.iter()
.map(|e| e.forged_ip)
.collect::<Vec<Ipv4Addr>>()
);
let mut new_reply = reply.clone().to_vec();
new_reply[0..2].copy_from_slice(&original_message[offset..][0..=1]);
new_reply[6..8].copy_from_slice(&(replies.len() as u16).to_be_bytes());
new_reply.extend_from_slice(&original_message[offset..][12..=12 + qname_parts.iter().map(|p| p.len()+1).sum::<usize>() + 1 + 3]);
for reply in replies {
new_reply.extend_from_slice(&parts_to_dns_name(&qname_parts));
new_reply.extend_from_slice(&[0, 1]);
new_reply.extend_from_slice(&[0, 1]);
new_reply.extend_from_slice(&((reply.expires - now).as_secs() as u32).to_be_bytes());
new_reply.extend_from_slice(&[0, 4]);
new_reply.extend_from_slice(&reply.forged_ip.octets());
}
if tcp {
reply_buf.extend_from_slice(&(new_reply.len() as u16).to_be_bytes());
}
reply_buf.extend_from_slice(&new_reply);
}
async fn handle_dns_response(buf: &[u8], reply_buf: &mut Vec<u8>, tcp: bool) -> anyhow::Result<()> {
let offset: usize = if tcp { 2 } else { 0 };
let mut cursor = Cursor::from(&buf[offset..]);
// Identify some metadata from the query
if cursor.seek(4).is_err() {
return Err(anyhow::Error::msg("Failed to seek to QDCount"));
}
let qdcount = u16::from_be_bytes(cursor.next_array::<2>().ok_or(anyhow::Error::msg("Failed to read QDCount"))?);
if qdcount != 1 {
eprintln!("Got qdcount: {}", qdcount);
return Err(anyhow::Error::msg(
"Missing question from query. Got qdcount {}",
));
}
if cursor.forward(6).is_err() {
return Err(anyhow::Error::msg("Failed to seek to question section"));
}
let qname_parts =
dns_name_to_parts(&mut cursor).ok_or(anyhow::Error::msg("Failed to decode QName."))?;
let qtype = u16::from_be_bytes(cursor.next_array::<2>().ok_or(anyhow::Error::msg("Failed to read QType."))?);
let qclass = u16::from_be_bytes(cursor.next_array::<2>().ok_or(anyhow::Error::msg("Failed to read QClass."))?);
// If the query is for anything other than qclass IN or qtype A, just forward the query upstream
if qclass != 1 || qtype != 1 {
let upstream_reply = query_upstream_resolvers(buf, tcp).await?;
reply_buf.extend_from_slice(&upstream_reply);
return Ok(())
}
let now = Instant::now();
let dns_name = parts_to_dns_name(&qname_parts);
let dns_name_string = dns_parts_to_string(&qname_parts);
let entries = match ForwardingMap.lock().await.get_mut(&dns_name) {
Some(forwardings) => forwardings.clone(),
None => Vec::new(),
};
// Let's first lookup the qname in the Forwardings to see if we have non-expired answers
if entries.len() > 0 && entries.iter().all(|e| e.expires > now) {
forge_replies(&entries, dns_name_string, qname_parts, buf, reply_buf, tcp);
return Ok(());
} else {
let upstream_reply = query_upstream_resolvers(buf, tcp).await?;
// Try an answer from the upstream response that has type A.
let a_answers = match AnswerIterator::from(&upstream_reply[offset..]) {
Some(answers) => answers
.filter(|Answer { rrtype, .. }| *rrtype == 1)
.collect::<Vec<Answer>>(),
None => {
return Err(anyhow::Error::msg(
"Failed to extract answers from upstream reply!",
));
}
};
let mut forge = true;
if qtype != 1 {
// Only forge for queries with an A qtype
eprintln!("Not forging due non-A type question.");
forge = false;
}
if a_answers.len() == 0 {
// If no A type answer, don't forge
eprintln!("Not forging due to no returned A type answers.");
forge = false;
}
if a_answers.iter().any(|a| a.rdata.len() != 4) {
eprintln!("Not forging due to malformed A type answer.",);
forge = false;
}
if !FwmarkConfigMap
.lock()
.await
.contains_key(&parts_to_dns_name(&qname_parts))
{
eprintln!("Not forging due to non-matching qname.");
forge = false;
}
if forge {
// Normalise a_answers so we're working with Ipv4Addr
let normalised_answers: Vec<(Vec<Vec<u8>>, u32, Ipv4Addr)> = a_answers
.iter()
.map(|answer| {
(
answer.name.clone(),
answer.ttl,
Ipv4Addr::from(
<Vec<u8> as TryInto<[u8; 4]>>::try_into(answer.rdata.clone()).unwrap(),
),
)
})
.collect::<_>();
let mut replies: Vec<Forwarding> = Vec::new();
// Determine if we need to create or renew our entries
if entries.len() > 0 && entries.iter().all(|e| e.expires > now) {
println!("Found not expired forwardings for {}", dns_name_string);
replies.extend(entries.clone() as Vec<Forwarding>);
} else {
// We want to identify which of our A answers already exist in the ForwardingMap
let existing_entries = entries
.iter()
.filter(|e| {
normalised_answers
.iter()
.any(|(_, _, addr)| e.original_ip == *addr)
})
.collect::<Vec<_>>();
// Let's also identify which entries don't match any of the replies
let nonexisting_entries = entries
.iter()
.filter(|e| {
!normalised_answers
.iter()
.any(|(_, _, addr)| e.original_ip == *addr)
})
.collect::<Vec<_>>();
// And now let's find the replies that don't match any of the current entries
let new_answers = normalised_answers
.iter()
.filter(|answer| entries.iter().all(|e| e.original_ip != answer.2))
.collect::<Vec<_>>();
// Acquire the forwarding map
let mut forwarding_map = ForwardingMap.lock().await;
// Remove all non-existing entries
for entry in nonexisting_entries {
println!(
"Removed forwarding for {} with real IP {} / forged IP {} as the upstream no longer returns this value",
dns_name_string, entry.original_ip, entry.forged_ip
);
teardown_forwarding(&mut forwarding_map, &dns_name, entry.original_ip).await;
}
// Add new answers
for (_, ttl, original_ip) in new_answers {
let forged_ip = IpAllocator.lock().await.acquire().unwrap();
setup_forwarding(
&mut forwarding_map,
&dns_name,
*ttl,
*original_ip,
forged_ip,
)
.await;
println!(
"Added forwarding for {} with real IP {} / forged IP {}",
dns_name_string, original_ip, forged_ip
);
}
// For all the existing entries, update their TTL if they're expired
for entry in existing_entries.iter() {
println!(
"Updating TTL of existing entry {}/{}",
dns_name_string, entry.forged_ip
);
if entry.expires < now {
teardown_forwarding(&mut forwarding_map, &dns_name, entry.original_ip).await;
setup_forwarding(
&mut forwarding_map,
&dns_name,
entry.ttl,
entry.original_ip,
entry.forged_ip,
)
.await;
}
}
let entries = forwarding_map.get(&dns_name).unwrap().clone();
replies.extend(entries);
}
forge_replies(&replies, dns_name_string, qname_parts, buf, reply_buf, tcp);
} else {
reply_buf.extend_from_slice(&upstream_reply);
}
Ok(())
}
}
#[derive(Clone, Debug)]
struct Forwarding {
expires: Instant,
forged_ip: Ipv4Addr,
original_ip: Ipv4Addr,
ttl: u32,
}
lazy_static! {
static ref FwmarkConfigMap: Mutex<HashMap<Vec<u8>, u32>> = Mutex::new(HashMap::new());
static ref ForwardingMap: Mutex<HashMap<Vec<u8>, Vec<Forwarding>>> = Mutex::new(HashMap::new());
static ref IpAllocator: Mutex<IpPool> =
Mutex::new(IpPool::new(Ipv4Addr::new(100, 64, 0, 0), 24).unwrap());
}
static COMMAND_PREFIX: OnceLock<String> = OnceLock::new();
static UPSTREAM_DNS: OnceLock<String> = OnceLock::new();
static LISTEN_ADDR: OnceLock<String> = OnceLock::new();
async fn tcp_handler(listener: TcpListener) {
loop {
let accepted = listener.accept().await;
if let Err(e) = &accepted {
eprintln!("[TCP] Failed to accept TCP socket with error: {:?}", e);
}
let (mut socket, addr) = accepted.unwrap();
let mut size_buf = [0u8; 2];
let Ok(2) = socket.read(&mut size_buf).await else {
eprintln!("[TCP] Failed to read message size from {}", addr);
continue;
};
let message_size = u16::from_be_bytes(size_buf) as usize;
let mut message_buffer = Vec::new();
message_buffer.resize(message_size + 2, 0);
match socket.read(&mut message_buffer[2..]).await {
Ok(n) => {
if n != message_size {
eprintln!(
"[TCP] Received too few bytes than expected from {}. Message size indicated {}, read {}",
addr, message_size, n
);
continue;
}
}
Err(e) => {
eprintln!(
"[TCP] Failed to read message from {} with error: {:?}",
addr, e
);
continue;
}
}
message_buffer[0] = size_buf[0];
message_buffer[1] = size_buf[1];
let mut reply = Vec::new();
if let Err(e) = handle_dns_response(&message_buffer, &mut reply, true).await {
eprintln!(
"[TCP] Received error when handling response for {}: {:?}",
addr, e
);
continue;
}
if let Err(e) = socket.write(&reply).await {
eprintln!(
"[TCP] Received error when sending response to {}: {:?}",
addr, e
);
continue;
}
}
}
async fn udp_handler(socket: UdpSocket) {
loop {
let mut message_buffer = [0u8; 512];
match socket.recv_from(&mut message_buffer).await {
Ok((_, addr)) => {
let mut reply = Vec::new();
if let Err(e) = handle_dns_response(&message_buffer, &mut reply, false).await {
eprintln!(
"[UDP] Received error when handling response for {}: {:?}",
addr, e
);
continue;
}
if let Err(e) = socket.send_to(&reply, addr).await {
eprintln!(
"[UDP] Received error when sending response to {}: {:?}",
addr, e
);
continue;
}
}
Err(e) => {
eprintln!("[UDP] Failed to read message with error: {:?}", e);
continue;
}
}
}
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let mut args = std::env::args();
let _ = args.next();
let config_file_path = match args.next() {
Some(p) => p,
None => String::from("/etc/rust-dns-selective-routing/config.ini"),
};
let i = Ini::load_from_file(config_file_path).unwrap();
let map: HashMap<String, HashMap<String, String>> = i
.into_iter()
.filter_map(|(k, it)| match k {
Some(s) => Some((s, it.into_iter().collect())),
None => None,
})
.collect();
if let Some(v) = map.get("config") {
if let Some(cmd) = v.get("command_prefix") {
COMMAND_PREFIX.set(cmd.clone()).unwrap();
} else {
COMMAND_PREFIX.set(String::from("")).unwrap();
}
if let Some(upstream) = v.get("upstream_dns") {
UPSTREAM_DNS.set(upstream.clone()).unwrap();
} else {
UPSTREAM_DNS.set(String::from("1.1.1.1:53")).unwrap();
}
if let Some(listen_addr) = v.get("listen_addr") {
LISTEN_ADDR.set(listen_addr.clone()).unwrap();
} else {
LISTEN_ADDR.set(String::from("127.0.0.1:53")).unwrap();
}
}
if let Some(v) = map.get("domains") {
let mut fwmark_config = FwmarkConfigMap.lock().await;
for (domain, fwmark) in v.iter() {
if let Ok(fwmark) = fwmark.parse::<u32>() {
let _ = fwmark_config.insert(string_to_dns_name(domain.to_string()), fwmark);
}
}
}
let recv_socket_tcp = TcpListener::bind(LISTEN_ADDR.get().unwrap()).await?;
let recv_socket_udp = UdpSocket::bind(LISTEN_ADDR.get().unwrap()).await?;
tokio::try_join!(
tokio::spawn(tcp_handler(recv_socket_tcp)),
tokio::spawn(udp_handler(recv_socket_udp)),
)?;
Ok(())
}