From 7d738329131fba49986e44df1a872e7b01b7747b Mon Sep 17 00:00:00 2001 From: Xnoe Date: Thu, 14 May 2026 23:44:30 +0100 Subject: [PATCH] Other changes --- src/dns_parser.rs | 16 +-- src/main.rs | 314 +++++++++++++++++++++++----------------------- 2 files changed, 163 insertions(+), 167 deletions(-) diff --git a/src/dns_parser.rs b/src/dns_parser.rs index 9ba8272..8e3fd19 100644 --- a/src/dns_parser.rs +++ b/src/dns_parser.rs @@ -81,19 +81,19 @@ pub struct AnswerIterator<'a> { } impl<'a> AnswerIterator<'a> { - pub fn from(buf: &'a [u8]) -> Option { + pub fn from(buf: &'a [u8]) -> Result { let mut cursor = Cursor::from(buf); - cursor.seek(4).ok()?; - let qdcount = u16::from_be_bytes(cursor.next_array::<2>()?); - let ancount = u16::from_be_bytes(cursor.next_array::<2>()?); - cursor.forward(4).ok()?; + cursor.seek(4).or(Err(()))?; + let qdcount = u16::from_be_bytes(cursor.next_array::<2>().ok_or(())?); + let ancount = u16::from_be_bytes(cursor.next_array::<2>().ok_or(())?); + cursor.forward(4).or(Err(()))?; // Skip past the question section for _ in 0..qdcount { - dns_name_len(&mut cursor)?; - cursor.forward(4).ok()?; + dns_name_len(&mut cursor).ok_or(()); + cursor.forward(4).or(Err(()))?; } - Some(Self { cursor, ancount }) + Ok(Self { cursor, ancount }) } } diff --git a/src/main.rs b/src/main.rs index 327eba4..f6933d9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -122,7 +122,7 @@ async fn teardown_forwarding( let _dnat_output = run_command_string(dnat_command).unwrap(); } -fn forge_replies(replies: &Vec, dns_name_string: String, qname_parts: Vec>, original_message: &[u8], reply_buf: &mut Vec) { +fn forge_replies(replies: &Vec, dns_name_string: String, qname_parts: Vec>, original_message: &[u8]) -> Vec { let reply: [u8; 12] = [ 0u8, 0, // ID @@ -165,15 +165,14 @@ fn forge_replies(replies: &Vec, dns_name_string: String, qname_parts new_reply.extend_from_slice(&[0, 4]); new_reply.extend_from_slice(&reply.forged_ip.octets()); } - reply_buf.extend_from_slice(&new_reply); + + new_reply } async fn handle_dns_response(mut handler: impl SelectiveRoutingHandler) -> anyhow::Result<()> { let buf = handler.get_original_message(); let mut cursor = Cursor::from(buf); - let mut reply_buf = Vec::new(); - // Identify some metadata from the query if cursor.seek(4).is_err() { return Err(anyhow::Error::msg("Failed to seek to QDCount")); @@ -182,9 +181,10 @@ async fn handle_dns_response(mut handler: impl SelectiveRoutingHandler) -> anyho let qdcount = u16::from_be_bytes(cursor.next_array::<2>().ok_or(anyhow::Error::msg("Failed to read QDCount"))?); if qdcount != 1 { eprintln!("Got qdcount: {}", qdcount); - return Err(anyhow::Error::msg( + return Err(anyhow::Error::msg(format!( "Missing question from query. Got qdcount {}", - )); + qdcount + ))); } if cursor.forward(6).is_err() { return Err(anyhow::Error::msg("Failed to seek to question section")); @@ -212,161 +212,157 @@ async fn handle_dns_response(mut handler: impl SelectiveRoutingHandler) -> anyho // Let's first lookup the qname in the Forwardings to see if we have non-expired answers if entries.len() > 0 && entries.iter().all(|e| e.expires > now) { - forge_replies(&entries, dns_name_string, qname_parts, buf, &mut reply_buf); - handler.reply(reply_buf).await.or(Err(anyhow::Error::msg("Failed to reply!")))?; + handler.reply( + forge_replies(&entries, dns_name_string, qname_parts, buf) + ).await.or(Err(anyhow::Error::msg("Failed to reply!")))?; return Ok(()); - } else { - let upstream_reply = handler.query_upstream().await.or(Err(anyhow::Error::msg("Failed to query upstream resolver.")))?; - - // Try an answer from the upstream response that has type A. - let a_answers = match AnswerIterator::from(&upstream_reply) { - Some(answers) => answers - .filter(|Answer { rrtype, .. }| *rrtype == 1) - .collect::>(), - None => { - return Err(anyhow::Error::msg( - "Failed to extract answers from upstream reply!", - )); - } - }; - - let mut forge = true; - if qtype != 1 { - // Only forge for queries with an A qtype - eprintln!("Not forging due non-A type question."); - forge = false; - } - - if a_answers.len() == 0 { - // If no A type answer, don't forge - eprintln!("Not forging due to no returned A type answers."); - forge = false; - } - - if a_answers.iter().any(|a| a.rdata.len() != 4) { - eprintln!("Not forging due to malformed A type answer.",); - forge = false; - } - - if !FwmarkConfigMap - .lock() - .await - .get(&parts_to_dns_name(&qname_parts)) - .is_some() - { - eprintln!("Not forging due to non-matching qname."); - forge = false; - } - - if forge { - // Normalise a_answers so we're working with Ipv4Addr - let normalised_answers: Vec<(Vec>, u32, Ipv4Addr)> = a_answers - .iter() - .map(|answer| { - ( - answer.name.clone(), - answer.ttl, - Ipv4Addr::from( - as TryInto<[u8; 4]>>::try_into(answer.rdata.clone()).unwrap(), - ), - ) - }) - .collect::<_>(); - - let mut replies: Vec = Vec::new(); - - // Determine if we need to create or renew our entries - if entries.len() > 0 && entries.iter().all(|e| e.expires > now) { - println!("Found not expired forwardings for {}", dns_name_string); - replies.extend(entries.clone() as Vec); - } else { - // We want to identify which of our A answers already exist in the ForwardingMap - let existing_entries = entries - .iter() - .filter(|e| { - normalised_answers - .iter() - .any(|(_, _, addr)| e.original_ip == *addr) - }) - .collect::>(); - - // Let's also identify which entries don't match any of the replies - let nonexisting_entries = entries - .iter() - .filter(|e| { - !normalised_answers - .iter() - .any(|(_, _, addr)| e.original_ip == *addr) - }) - .collect::>(); - - // And now let's find the replies that don't match any of the current entries - let new_answers = normalised_answers - .iter() - .filter(|answer| entries.iter().all(|e| e.original_ip != answer.2)) - .collect::>(); - - // Acquire the forwarding map - let mut forwarding_map = ForwardingMap.lock().await; - - // Remove all non-existing entries - for entry in nonexisting_entries { - println!( - "Removed forwarding for {} with real IP {} / forged IP {} as the upstream no longer returns this value", - dns_name_string, entry.original_ip, entry.forged_ip - ); - teardown_forwarding(&mut forwarding_map, &dns_name, entry.original_ip).await; - } - - // Add new answers - for (_, ttl, original_ip) in new_answers { - let forged_ip = IpAllocator.lock().await.acquire().unwrap(); - setup_forwarding( - &mut forwarding_map, - &dns_name, - *ttl, - *original_ip, - forged_ip, - ) - .await; - println!( - "Added forwarding for {} with real IP {} / forged IP {}", - dns_name_string, original_ip, forged_ip - ); - } - - // For all the existing entries, update their TTL if they're expired - for entry in existing_entries.iter() { - println!( - "Updating TTL of existing entry {}/{}", - dns_name_string, entry.forged_ip - ); - if entry.expires < now { - teardown_forwarding(&mut forwarding_map, &dns_name, entry.original_ip).await; - setup_forwarding( - &mut forwarding_map, - &dns_name, - entry.ttl, - entry.original_ip, - entry.forged_ip, - ) - .await; - } - } - - let entries = forwarding_map.get(&dns_name).unwrap().clone(); - replies.extend(entries); - } - - forge_replies(&replies, dns_name_string, qname_parts, buf, &mut reply_buf); - } else { - reply_buf.extend_from_slice(&upstream_reply); - } - - handler.reply(reply_buf).await.or(Err(anyhow::Error::msg("Failed to reply.")))?; - - Ok(()) } + + let upstream_reply = handler.query_upstream().await.or(Err(anyhow::Error::msg("Failed to query upstream resolver.")))?; + + // Try an answer from the upstream response that has type A. + let answer_iterator = AnswerIterator::from(&upstream_reply).or(Err(anyhow::Error::msg( + "Failed to extract answers from upstream reply!", + )))?; + let a_answers = answer_iterator + .filter(|Answer { rrtype, .. }| *rrtype == 1) + .collect::>(); + + let mut should_forge = true; + if qtype != 1 { + // Only forge for queries with an A qtype + eprintln!("Not forging due non-A type question."); + should_forge = false; + } + + if a_answers.len() == 0 { + // If no A type answer, don't forge + eprintln!("Not forging due to no returned A type answers."); + should_forge = false; + } + + if a_answers.iter().any(|a| a.rdata.len() != 4) { + eprintln!("Not forging due to malformed A type answer.",); + should_forge = false; + } + + if !FwmarkConfigMap + .lock() + .await + .get(&parts_to_dns_name(&qname_parts)) + .is_some() + { + eprintln!("Not forging due to non-matching qname."); + should_forge = false; + } + + if !should_forge { + handler.reply(upstream_reply).await.or(Err(anyhow::Error::msg("Failed to reply.")))?; + return Ok(()) + } + + // Normalise a_answers so we're working with Ipv4Addr + let normalised_answers: Vec<(Vec>, u32, Ipv4Addr)> = a_answers + .into_iter() + .map(|answer| { + ( + answer.name, + answer.ttl, + Ipv4Addr::from( + as TryInto<[u8; 4]>>::try_into(answer.rdata).unwrap(), + ), + ) + }) + .collect::<_>(); + + let mut replies: Vec = Vec::new(); + + // Determine if we need to create or renew our entries + if entries.len() > 0 && entries.iter().all(|e| e.expires > now) { + println!("Found not expired forwardings for {}", dns_name_string); + replies.extend(entries.clone() as Vec); + } else { + // We want to identify which of our A answers already exist in the ForwardingMap + let existing_entries = entries + .iter() + .filter(|e| { + normalised_answers + .iter() + .any(|(_, _, addr)| e.original_ip == *addr) + }) + .collect::>(); + + // Let's also identify which entries don't match any of the replies + let nonexisting_entries = entries + .iter() + .filter(|e| { + !normalised_answers + .iter() + .any(|(_, _, addr)| e.original_ip == *addr) + }) + .collect::>(); + + // And now let's find the replies that don't match any of the current entries + let new_answers = normalised_answers + .iter() + .filter(|answer| entries.iter().all(|e| e.original_ip != answer.2)) + .collect::>(); + + // Acquire the forwarding map + let mut forwarding_map = ForwardingMap.lock().await; + + // Remove all non-existing entries + for entry in nonexisting_entries { + println!( + "Removed forwarding for {} with real IP {} / forged IP {} as the upstream no longer returns this value", + dns_name_string, entry.original_ip, entry.forged_ip + ); + teardown_forwarding(&mut forwarding_map, &dns_name, entry.original_ip).await; + } + + // Add new answers + for (_, ttl, original_ip) in new_answers { + let forged_ip = IpAllocator.lock().await.acquire().unwrap(); + setup_forwarding( + &mut forwarding_map, + &dns_name, + *ttl, + *original_ip, + forged_ip, + ) + .await; + println!( + "Added forwarding for {} with real IP {} / forged IP {}", + dns_name_string, original_ip, forged_ip + ); + } + + // For all the existing entries, update their TTL if they're expired + for entry in existing_entries.iter() { + println!( + "Updating TTL of existing entry {}/{}", + dns_name_string, entry.forged_ip + ); + if entry.expires < now { + teardown_forwarding(&mut forwarding_map, &dns_name, entry.original_ip).await; + setup_forwarding( + &mut forwarding_map, + &dns_name, + entry.ttl, + entry.original_ip, + entry.forged_ip, + ) + .await; + } + } + + let entries = forwarding_map.get(&dns_name).unwrap().clone(); + replies.extend(entries); + } + + handler.reply(forge_replies(&replies, dns_name_string, qname_parts, buf)).await.or(Err(anyhow::Error::msg("Failed to reply.")))?; + + Ok(()) } #[derive(Clone, Debug)]