From 855b09e580e10826ed15d98c9c7e30dbb363183a Mon Sep 17 00:00:00 2001 From: Xnoe Date: Sat, 29 Oct 2022 19:53:33 +0100 Subject: [PATCH] Actually functional DHCP client now. --- Cargo.lock | 42 ------- Cargo.toml | 1 - src/main.rs | 319 ++++++++++++++++++++++++++++++++++++----------- src/rawsocket.rs | 133 ++++++++++++++++++++ src/rtnetlink.rs | 247 ++++++++++++++++++++++-------------- 5 files changed, 529 insertions(+), 213 deletions(-) create mode 100644 src/rawsocket.rs diff --git a/Cargo.lock b/Cargo.lock index 1479831..d795051 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -11,18 +11,6 @@ dependencies = [ "memchr", ] -[[package]] -name = "autocfg" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" - -[[package]] -name = "bitflags" -version = "1.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" - [[package]] name = "cfg-if" version = "1.0.0" @@ -69,35 +57,6 @@ version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" -[[package]] -name = "memoffset" -version = "0.6.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5aa361d4faea93603064a027415f07bd8e1d5c88c9fbf68bf56a285428fd79ce" -dependencies = [ - "autocfg", -] - -[[package]] -name = "nix" -version = "0.25.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e322c04a9e3440c327fca7b6c8a63e6890a32fa2ad689db972425f07e0d22abb" -dependencies = [ - "autocfg", - "bitflags", - "cfg-if", - "libc", - "memoffset", - "pin-utils", -] - -[[package]] -name = "pin-utils" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" - [[package]] name = "ppv-lite86" version = "0.2.16" @@ -158,7 +117,6 @@ dependencies = [ "dhcprs", "eui48", "libc", - "nix", "rand", ] diff --git a/Cargo.toml b/Cargo.toml index 94b03ef..c55a9c2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,7 +7,6 @@ edition = "2021" [dependencies] dhcprs = {version = "0.1.3", path = "../rust-dhcprs"} -nix = "0.25.0" libc = "0.2.133" eui48 = "1.1.0" rand = "0.8.5" \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index 1fb36d2..2dfbf41 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,7 @@ pub mod rtnetlink; +pub mod rawsocket; use eui48::MacAddress; -use nix::sys::socket::*; use rand::prelude::*; use std::net::Ipv4Addr; @@ -10,6 +10,8 @@ use dhcprs::dhcp::DHCPOption; use std::collections::HashMap; +use std::io::Write; + fn create_dhcp_packet( xid: u32, mac: MacAddress, @@ -56,44 +58,40 @@ enum DHCPTransactionState { WaitingAfterDiscover, Request, WaitAfterRequest, - Renew, - WaitAfterRenew, - Rebind, - WaitafterRebind, + Renew +} + +fn pick_weighted(list: &Vec) -> Option<&T> { + let len = list.len(); + let mut rng = rand::thread_rng(); + let mut prob = rng.gen_range(0f64..=1f64); + + println!("start prob={}", prob); + for (idx, elem) in list.iter().enumerate() { + let weight = 1f64 / (2 as usize).pow(idx as u32 + (len != idx + 1) as u32) as f64; + + prob -= weight; + + println!("len={} idx={}: weight={}, prob={}", len, idx, weight, prob); + + if prob <= 0.0 { + return Some(elem); + } + + } + + // fallback in case of float rounding errors i suppose + list.choose(&mut rng) } fn dhcp_client(name: String, index: i32, mac: [u8; 6]) { + unsafe { libc::sleep(5); } + let mut name = name; + let mut socket_nl = rtnetlink::create_netlink_socket(false).unwrap(); let mut rng = rand::thread_rng(); // Before we can do anything else, we need to construct an LinkAddr for the mac ff:ff:ff:ff:ff:ff - let sockaddr_ll = libc::sockaddr_ll { - sll_family: libc::AF_PACKET as u16, - sll_halen: 6, - sll_hatype: 1, - sll_ifindex: index, - sll_addr: [0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0], - sll_pkttype: 0, - sll_protocol: 0x0008, - }; - - let mut linkaddr = unsafe { - LinkAddr::from_raw( - &sockaddr_ll as *const libc::sockaddr_ll as *const libc::sockaddr, - Some(20), - ) - .expect("Failed to create linkaddr!") - }; - - // Create and bind the socket - let socket = socket( - AddressFamily::Packet, - SockType::Datagram, - SockFlag::empty(), - SockProtocol::Udp, - ) - .expect("Failed to create socket! Permission issue?"); - assert!(setsockopt(socket, sockopt::Broadcast, &true).is_ok()); - assert!(bind(socket, &linkaddr).is_ok()); + let mut socket = rawsocket::create_raw_socket(index, mac).unwrap(); let client_mac = MacAddress::from_bytes(&mac).unwrap(); @@ -109,10 +107,9 @@ fn dhcp_client(name: String, index: i32, mac: [u8; 6]) { 'dhcp_message_loop: loop { match &dhcp_state { DHCPTransactionState::Discover => { - println!("Sent DHCPDiscover on {}", name); loop { - match sendto( - socket, + println!("Sent DHCPDiscover on {}", name); + match socket.send( create_dhcp_packet( xid, client_mac, @@ -123,15 +120,65 @@ fn dhcp_client(name: String, index: i32, mac: [u8; 6]) { DHCPOption::End, ], ) - .as_bytes(), - &linkaddr, - MsgFlags::empty(), + .as_bytes() ) { Ok(_) => break, - Err(nix::errno::Errno::ENETDOWN) => { - unsafe { libc::sleep(3) }; + Err(libc::ENETDOWN) => { + unsafe { + { + #[repr(packed)] + struct Request {a: libc::nlmsghdr, b: rtnetlink::ifinfomsg} + let mut request = std::mem::zeroed::(); + + request.a.nlmsg_len = std::mem::size_of::() as u32; + request.a.nlmsg_type = libc::RTM_NEWLINK; + request.a.nlmsg_flags = (libc::NLM_F_REQUEST) as u16; + request.b.ifi_index = index; + request.b.ifi_flags = libc::IFF_UP as u32; + request.b.ifi_change = 1; + + socket_nl.send(rtnetlink::to_slice(&request)).unwrap(); + } + + { + #[repr(packed)] + struct Request {a: libc::nlmsghdr, b: rtnetlink::ifinfomsg} + let mut request = std::mem::zeroed::(); + + request.a.nlmsg_len = std::mem::size_of::() as u32; + request.a.nlmsg_type = libc::RTM_GETLINK; + request.a.nlmsg_flags = (libc::NLM_F_REQUEST) as u16; + request.b.ifi_index = index; + + socket_nl.send(rtnetlink::to_slice(&request)).unwrap(); + + let mut buffer = [0; 4096]; + socket_nl.recv(&mut buffer).unwrap(); + + let nlmsghdr = &buffer as *const _ as *const libc::nlmsghdr; + if (*nlmsghdr).nlmsg_type != libc::RTM_NEWLINK { + panic!("Wrong message!"); + } + + let mut n = (*nlmsghdr).nlmsg_len - std::mem::size_of::() as u32 - std::mem::size_of::() as u32; + let mut rtattr = (&buffer as *const _ as *const u8).offset(std::mem::size_of::() as isize + std::mem::size_of::() as isize) as *const rtnetlink::rtattr; + + while rtnetlink::rta_ok(rtattr, &mut n) { + match (*rtattr).rta_type { + libc::IFLA_IFNAME => { + name = zascii(&std::ptr::read(rtnetlink::rta_data(rtattr) as *const _ as *const [libc::c_char; libc::IFNAMSIZ])); + } + _ => () + } + + rtattr = rtnetlink::rta_next(rtattr, &mut n); + } + } + + libc::sleep(10); + } } - Err(_) => panic!("") + Err(_) => panic!("Failed to send on {}", name) } } dhcp_state = DHCPTransactionState::WaitingAfterDiscover; @@ -139,9 +186,7 @@ fn dhcp_client(name: String, index: i32, mac: [u8; 6]) { DHCPTransactionState::WaitingAfterDiscover => { let mut packet: [u8; 574] = [0; 574]; - let (bytes, addr) = recvfrom::(socket, &mut packet).unwrap(); - - println!("Received {} bytes on {}", bytes, name); + let (_, addr) = socket.recv(&mut packet).unwrap(); let udppacket_raw: dhcprs::udpbuilder::RawUDPPacket = unsafe { std::ptr::read(packet.as_ptr() as *const _) }; @@ -194,33 +239,14 @@ fn dhcp_client(name: String, index: i32, mac: [u8; 6]) { ); // Update the linkaddr to be the server's actual hardware address rather than the broadcast MAC. - let mut mac: [u8; 8] = [0; 8]; - mac[..6].copy_from_slice(&addr.unwrap().addr().unwrap()); - let new_sockaddr_ll = libc::sockaddr_ll { - sll_family: sockaddr_ll.sll_family, - sll_halen: sockaddr_ll.sll_halen, - sll_hatype: sockaddr_ll.sll_hatype, - sll_ifindex: sockaddr_ll.sll_ifindex, - sll_addr: mac, - sll_pkttype: sockaddr_ll.sll_pkttype, - sll_protocol: sockaddr_ll.sll_protocol, - }; - - linkaddr = unsafe { - LinkAddr::from_raw( - &new_sockaddr_ll as *const libc::sockaddr_ll as *const libc::sockaddr, - Some(20), - ) - .expect("Failed to update linkaddr to server's hardware address!") - }; + socket.set_destination_to(addr); dhcp_state = DHCPTransactionState::Request; } DHCPTransactionState::Request => { println!("Sent DHCPRequest on {}", name); - let _ = sendto( - socket, + socket.send( create_dhcp_packet( xid, client_mac, @@ -234,17 +260,14 @@ fn dhcp_client(name: String, index: i32, mac: [u8; 6]) { DHCPOption::End, ], ) - .as_bytes(), - &linkaddr, - MsgFlags::empty(), - ); + .as_bytes() + ).unwrap(); dhcp_state = DHCPTransactionState::WaitAfterRequest; } DHCPTransactionState::WaitAfterRequest => { let mut packet: [u8; 574] = [0; 574]; - let (bytes, _) = recvfrom::(socket, &mut packet).unwrap(); - println!("Received {} bytes on {}", bytes, name); + socket.recv(&mut packet).unwrap(); let udppacket_raw: dhcprs::udpbuilder::RawUDPPacket = unsafe { std::ptr::read(packet.as_ptr() as *const _) }; @@ -285,22 +308,168 @@ fn dhcp_client(name: String, index: i32, mac: [u8; 6]) { } let mut sleep_time = 0; + let mut subnet_mask = Ipv4Addr::new(0,0,0,0); + let mut router = Vec::new(); + let mut dns = Vec::new(); + let mut classless_static_routes = Vec::new(); for option in options { println!("Received: {:?}", option); match option { DHCPOption::IPAddressLeaseTime(n) => sleep_time = n, + DHCPOption::SubnetMask(m) => subnet_mask = m, + DHCPOption::Router(v) => router = v, + DHCPOption::DomainNameServer(v) => dns = v, + DHCPOption::ClasslessStaticRoute(v) => classless_static_routes = v, _ => (), } } - println!("Sleeping for {} for lease time to elapse.", sleep_time); - std::thread::sleep(core::time::Duration::new(sleep_time.into(), 0)); + unsafe { + { + #[repr(packed)] + struct Request {a: libc::nlmsghdr, b: rtnetlink::ifaddrmsg, c: rtnetlink::rtattr, d: [u8; 4], e: rtnetlink::rtattr, f: [u8; 4], g: rtnetlink::rtattr, h: rtnetlink::ifa_cacheinfo} + let mut request = std::mem::zeroed::(); + + request.a.nlmsg_len = std::mem::size_of::() as u32; + request.a.nlmsg_type = libc::RTM_NEWADDR; + request.a.nlmsg_flags = (libc::NLM_F_REQUEST | libc::NLM_F_CREATE | libc::NLM_F_REPLACE) as u16; + request.b.ifa_family = libc::AF_INET as u8; + request.b.ifa_index = index; + request.b.ifa_prefixlen = u32::from(subnet_mask).leading_ones() as u8; + request.c.rta_type = libc::IFA_LOCAL; + request.c.rta_len = std::mem::size_of::<(rtnetlink::rtattr, [u8; 4])>() as u16; + request.d = client_addr.unwrap().octets(); + request.e.rta_type = libc::IFA_BROADCAST; + request.e.rta_len = std::mem::size_of::<(rtnetlink::rtattr, [u8; 4])>() as u16; + request.f = ((u32::from(client_addr.unwrap()) & u32::from(subnet_mask)) | (!u32::from(subnet_mask))).to_be_bytes(); + request.g.rta_type = libc::IFA_CACHEINFO; + request.g.rta_len = std::mem::size_of::<(rtnetlink::rtattr, rtnetlink::ifa_cacheinfo)>() as u16; + request.h.ifa_preferred = (sleep_time as f32 * 0.5) as u32; + request.h.ifa_valid = (sleep_time as f32 * 0.875) as u32; + + socket_nl.send(rtnetlink::to_slice(&request)).unwrap(); + } + + { + #[repr(packed)] + struct Request {a: libc::nlmsghdr, b: rtnetlink::rtmsg, c: rtnetlink::rtattr, d: [u8; 4], e: rtnetlink::rtattr, f: libc::c_int, g: rtnetlink::rtattr, h: rtnetlink::rta_cacheinfo} + + let mut request = std::mem::zeroed::(); + + request.a.nlmsg_len = std::mem::size_of::() as u32; + request.a.nlmsg_type = libc::RTM_NEWROUTE; + request.a.nlmsg_flags = (libc::NLM_F_REQUEST | libc::NLM_F_CREATE) as u16; + request.b.rtm_family = libc::AF_INET as u8; + request.b.rtm_table = libc::RT_TABLE_MAIN; + request.b.rtm_type = libc::RTN_UNICAST; + request.b.rtm_protocol = 16; // DHCP + request.b.rtm_scope = libc::RT_SCOPE_LINK; + request.b.rtm_dst_len = u32::from(subnet_mask).leading_ones() as u8; + request.c.rta_type = libc::RTA_DST; + request.c.rta_len = std::mem::size_of::<(rtnetlink::rtattr, [u8; 4])>() as u16; + request.d = client_addr.unwrap().octets(); + request.e.rta_type = libc::RTA_OIF; + request.e.rta_len = std::mem::size_of::<(rtnetlink::rtattr, libc::c_int)>() as u16; + request.f = index; + request.g.rta_type = libc::RTA_CACHEINFO; + request.g.rta_len = std::mem::size_of::<(rtnetlink::rtattr, rtnetlink::rta_cacheinfo)>() as u16; + request.h.rta_expires = (sleep_time as f32 * 0.875) as u32; + + socket_nl.send(rtnetlink::to_slice(&request)).unwrap(); + } + + { + #[repr(packed)] + struct Request {a: libc::nlmsghdr, b: rtnetlink::rtmsg, c: rtnetlink::rtattr, d: [u8; 4], e: rtnetlink::rtattr, f: libc::c_int, g: rtnetlink::rtattr, h: rtnetlink::rta_cacheinfo} + + let mut request = std::mem::zeroed::(); + + request.a.nlmsg_len = std::mem::size_of::() as u32; + request.a.nlmsg_type = libc::RTM_NEWROUTE; + request.a.nlmsg_flags = (libc::NLM_F_REQUEST | libc::NLM_F_CREATE) as u16; + request.b.rtm_family = libc::AF_INET as u8; + request.b.rtm_table = libc::RT_TABLE_MAIN; + request.b.rtm_type = libc::RTN_UNICAST; + request.b.rtm_protocol = 16; // DHCP + request.b.rtm_scope = libc::RT_SCOPE_UNIVERSE; + request.b.rtm_dst_len = 0; + request.c.rta_type = libc::RTA_GATEWAY; + request.c.rta_len = std::mem::size_of::<(rtnetlink::rtattr, [u8; 4])>() as u16; + request.d = pick_weighted(&router).unwrap().octets(); + request.e.rta_type = libc::RTA_OIF; + request.e.rta_len = std::mem::size_of::<(rtnetlink::rtattr, libc::c_int)>() as u16; + request.f = index; + request.g.rta_type = libc::RTA_CACHEINFO; + request.g.rta_len = std::mem::size_of::<(rtnetlink::rtattr, rtnetlink::rta_cacheinfo)>() as u16; + request.h.rta_expires = (sleep_time as f32 * 0.875) as u32; + + socket_nl.send(rtnetlink::to_slice(&request)).unwrap(); + } + + { + #[repr(packed)] + struct Request {a: libc::nlmsghdr, b: rtnetlink::rtmsg, c: rtnetlink::rtattr, d: [u8; 4], e: rtnetlink::rtattr, f: [u8; 4], g: rtnetlink::rtattr, h: libc::c_int} + + for (prefix, prefix_len, router) in classless_static_routes { + let mut request = std::mem::zeroed::(); + + request.a.nlmsg_len = std::mem::size_of::() as u32; + request.a.nlmsg_type = libc::RTM_NEWROUTE; + request.a.nlmsg_flags = (libc::NLM_F_REQUEST | libc::NLM_F_CREATE) as u16; + request.b.rtm_family = libc::AF_INET as u8; + request.b.rtm_table = libc::RT_TABLE_MAIN; + request.b.rtm_type = libc::RTN_UNICAST; + request.b.rtm_protocol = 16; // DHCP + request.b.rtm_scope = libc::RT_SCOPE_UNIVERSE; + request.b.rtm_dst_len = prefix_len; + request.c.rta_type = libc::RTA_DST; + request.c.rta_len = std::mem::size_of::<(rtnetlink::rtattr, [u8; 4])>() as u16; + request.d = prefix.octets(); + request.e.rta_type = libc::RTA_GATEWAY; + request.e.rta_len = std::mem::size_of::<(rtnetlink::rtattr, [u8; 4])>() as u16; + request.f = router.octets(); + request.g.rta_type = libc::RTA_OIF; + request.g.rta_len = std::mem::size_of::<(rtnetlink::rtattr, libc::c_int)>() as u16; + request.h = index; + + + socket_nl.send(rtnetlink::to_slice(&request)).unwrap(); + } + } + + let mut f = std::fs::OpenOptions::new().write(true).truncate(true).open("/etc/resolv.conf").unwrap(); + for server in dns { + f.write_all(("nameserver ".to_string() + &server.to_string() + "\n").as_bytes()).unwrap(); + } + f.flush().unwrap(); + } + + println!("Sleeping for {} for lease time to elapse.", (sleep_time as f32 * 0.5) as u32); + std::thread::sleep(core::time::Duration::new((sleep_time as f32 * 0.5) as u64, 0)); + dhcp_state = DHCPTransactionState::Renew } - _ => panic!("Fail!"), + DHCPTransactionState::Renew => { + println!("[Renew] Sent DHCPRequest on {}", name); + let _ = socket.send( + create_dhcp_packet( + xid, + client_mac, + client_addr, + None, + vec![ + DHCPOption::DHCPMessageType(DHCPMessageType::DHCPRequest), + DHCPOption::ParameterRequest(vec![1, 3, 6, 28, 121]), + DHCPOption::End, + ], + ) + .as_bytes() + ); + dhcp_state = DHCPTransactionState::WaitAfterRequest; + } } } } diff --git a/src/rawsocket.rs b/src/rawsocket.rs new file mode 100644 index 0000000..70744de --- /dev/null +++ b/src/rawsocket.rs @@ -0,0 +1,133 @@ +fn pad_array(src: &[T]) -> [T; N] { + let mut arr = [Default::default(); N]; + arr[..src.len()].copy_from_slice(src); + arr +} + +pub struct RawSocket { + socket: libc::c_int, + src_addr: libc::sockaddr_ll, + dest_addr: libc::sockaddr_ll, + if_index: i32 +} + +pub fn create_raw_socket(index: i32, mac: [u8; 6]) -> Result { + let src_addr = libc::sockaddr_ll { + sll_family: libc::AF_PACKET as u16, + sll_halen: 6, + sll_hatype: 1, + sll_ifindex: index, + sll_addr: pad_array(&mac), + sll_pkttype: 0, + sll_protocol: 0x0008, + }; + + let dest_addr = libc::sockaddr_ll { + sll_family: libc::AF_PACKET as u16, + sll_halen: 6, + sll_hatype: 1, + sll_ifindex: index, + sll_addr: [0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0], + sll_pkttype: 0, + sll_protocol: 0x0008, + }; + + let socket; + unsafe { + socket = libc::socket(libc::AF_PACKET, libc::SOCK_DGRAM, libc::IPPROTO_UDP); + if socket == -1 { + return Err("Failed to create socket"); + } + + let mut timeval = std::mem::zeroed::(); + timeval.tv_sec = 30; + + if libc::setsockopt(socket, libc::SOL_SOCKET, libc::SO_RCVTIMEO, &timeval as *const _ as *const libc::c_void, std::mem::size_of::() as u32) == -1 { + return Err("Failed to set SO_RCVTIMEO"); + } + + if libc::setsockopt(socket, libc::SOL_SOCKET, libc::SO_BROADCAST, &(1 as u32) as *const _ as *const libc::c_void, std::mem::size_of::() as u32) == -1 { + return Err("Failed to set SO_BROADCAST") + } + + if libc::bind(socket, &src_addr as *const _ as *const libc::sockaddr, std::mem::size_of::() as u32) == -1 { + return Err("Failed to bind to socket"); + } + } + + Ok(RawSocket { + socket: socket, + src_addr: src_addr, + dest_addr: dest_addr, + if_index: index + }) +} + +impl RawSocket { + pub fn send(&mut self, msg: &[u8]) -> Result { + let mut iov = libc::iovec { + iov_base: msg as *const _ as *mut libc::c_void, + iov_len: msg.len() as usize, + }; + + let mut msg = unsafe { std::mem::zeroed::() }; + msg.msg_name = &mut self.dest_addr as *mut _ as *mut libc::c_void; + msg.msg_namelen = std::mem::size_of::() as u32; + msg.msg_iov = &mut iov as *mut libc::iovec; + msg.msg_iovlen = 1; + + let count = unsafe { libc::sendmsg(self.socket, &msg as *const libc::msghdr, 0) }; + if count < 0 { + return Err(unsafe { *libc::__errno_location() }); + } + + Ok(count) + } + + pub fn recv(&mut self, buffer: &mut [u8; N]) -> Result<(isize, [u8; 6]), i32> { + let mut iov = libc::iovec { + iov_base: buffer as *mut _ as *mut libc::c_void, + iov_len: N, + }; + + let mut msg = unsafe { std::mem::zeroed::() }; + msg.msg_name = &mut self.src_addr as *mut _ as *mut libc::c_void; + msg.msg_namelen = std::mem::size_of::() as u32; + msg.msg_iov = &mut iov as *mut libc::iovec; + msg.msg_iovlen = 1; + + let n = unsafe { libc::recvmsg(self.socket, &mut msg as *mut libc::msghdr, 0) }; + if n < 0 { + match unsafe { *libc::__errno_location() } { + libc::EAGAIN => return Ok((0, [0, 0, 0, 0, 0, 0])), + err => return Err(err) + } + } else { + return Ok((n, unsafe { std::ptr::read(msg.msg_name as *mut _ as *mut libc::sockaddr_ll) }.sll_addr[..6].try_into().unwrap())); + } + } + + pub fn set_destination_to_broadcast(&mut self) { + self.dest_addr = libc::sockaddr_ll { + sll_family: libc::AF_PACKET as u16, + sll_halen: 6, + sll_hatype: 1, + sll_ifindex: self.if_index, + sll_addr: [0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0], + sll_pkttype: 0, + sll_protocol: 0x0008, + }; + } + + pub fn set_destination_to(&mut self, addr: [u8; 6]) { + self.dest_addr = libc::sockaddr_ll { + sll_family: libc::AF_PACKET as u16, + sll_halen: 6, + sll_hatype: 1, + sll_ifindex: self.if_index, + sll_addr: pad_array(&addr), + sll_pkttype: 0, + sll_protocol: 0x0008, + }; + } +} \ No newline at end of file diff --git a/src/rtnetlink.rs b/src/rtnetlink.rs index 341b54a..482da5d 100644 --- a/src/rtnetlink.rs +++ b/src/rtnetlink.rs @@ -1,15 +1,58 @@ -use eui48::MacAddress; - #[repr(C)] -struct ifinfomsg { - ifi_family: libc::__u8, - ifi_pad: libc::__u8, - ifi_type: libc::__u16, - ifi_index: libc::c_int, - ifi_flags: libc::__u32, - ifi_change: libc::__u32, +pub struct ifinfomsg { + pub ifi_family: libc::__u8, + pub ifi_pad: libc::__u8, + pub ifi_type: libc::__u16, + pub ifi_index: libc::c_int, + pub ifi_flags: libc::__u32, + pub ifi_change: libc::__u32, } +#[repr(C)] +pub struct ifaddrmsg { + pub ifa_family: libc::c_uchar, + pub ifa_prefixlen: libc::c_uchar, + pub ifa_flags: libc::c_uchar, + pub ifa_scope: libc::c_uchar, + pub ifa_index: libc::c_int, +} + +#[repr(C)] +pub struct rtmsg { + pub rtm_family: libc::c_uchar, + pub rtm_dst_len: libc::c_uchar, + pub rtm_src_len: libc::c_uchar, + pub rtm_tos: libc::c_uchar, + pub rtm_table: libc::c_uchar, + + pub rtm_protocol: libc::c_uchar, + pub rtm_scope: libc::c_uchar, + pub rtm_type: libc::c_uchar, + + pub rtm_flags: libc::c_uint, +} + +#[repr(C)] +pub struct ifa_cacheinfo { + pub ifa_preferred: libc::__u32, + pub ifa_valid: libc::__u32, + pub cstamp: libc::__u32, + pub tstamp: libc::__u32, +} + +#[repr(C)] +pub struct rta_cacheinfo { + pub rta_clntref: libc::__u32, + pub rta_lastuse: libc::__u32, + pub rta_expires: libc::__u32, + pub rta_error: libc::__u32, + pub rta_used: libc::__u32, + pub rta_id: libc::__u32, + pub rta_ts: libc::__u32, + pub rta_tsage: libc::__u32, +} + +#[allow(non_upper_case_globals)] const nlmsg_alignto: u32 = 4; macro_rules! nlmsg_align { @@ -18,22 +61,15 @@ macro_rules! nlmsg_align { }; } +#[allow(non_upper_case_globals)] const nlmsg_hdrlen: usize = nlmsg_align!(std::mem::size_of::()) as usize; -macro_rules! nlmsg_length { - ($len:expr) => { - $len as u32 + nlmsg_hdrlen as u32 - }; -} - -macro_rules! nlmsg_space { - ($len:expr) => { - nlmsg_align!(nlmsg_length!($len)) - }; +pub fn nlmsg_length(len: u32) -> u32 { + len as u32 + nlmsg_hdrlen as u32 } #[inline] -unsafe fn nlmsg_data(nlh: *const libc::nlmsghdr) -> *const u8 { +pub unsafe fn nlmsg_data(nlh: *const libc::nlmsghdr) -> *const u8 { return (nlh as *const u8).offset(nlmsg_hdrlen as isize); } @@ -51,17 +87,13 @@ unsafe fn nlmsg_ok(nlh: *const libc::nlmsghdr, len: &mut u32) -> bool { && (*nlh).nlmsg_len <= *len; } -#[inline] -unsafe fn nlmsg_payload(nlh: *const libc::nlmsghdr, len: &mut u32) -> u32 { - return (*nlh).nlmsg_len - nlmsg_space!(*len); -} - #[repr(C)] -struct rtattr { - rta_len: libc::c_ushort, - rta_type: libc::c_ushort, +pub struct rtattr { + pub rta_len: libc::c_ushort, + pub rta_type: libc::c_ushort, } +#[allow(non_upper_case_globals)] const rta_alignto: u32 = 4; macro_rules! rta_align { @@ -71,14 +103,14 @@ macro_rules! rta_align { } #[inline] -unsafe fn rta_ok(rta: *const rtattr, len: &mut u32) -> bool { +pub unsafe fn rta_ok(rta: *const rtattr, len: &mut u32) -> bool { return *len >= std::mem::size_of::() as u32 && (*rta).rta_len >= std::mem::size_of::() as libc::c_ushort && (*rta).rta_len <= *len as libc::c_ushort; } #[inline] -unsafe fn rta_next(rta: *const rtattr, len: &mut u32) -> *const rtattr { +pub unsafe fn rta_next(rta: *const rtattr, len: &mut u32) -> *const rtattr { *len -= rta_align!((*rta).rta_len); return (rta as *const u8).offset(nlmsg_align!((*rta).rta_len) as isize) as *const rtattr; } @@ -89,12 +121,7 @@ fn rta_length(len: u32) -> u32 { } #[inline] -fn rta_space(len: u32) -> u32 { - rta_align!(rta_length(len)) -} - -#[inline] -unsafe fn rta_data(rta: *const rtattr) -> *const u8 { +pub unsafe fn rta_data(rta: *const rtattr) -> *const u8 { return (rta as *const u8).offset(rta_length(0) as isize); } @@ -125,71 +152,112 @@ unsafe fn ifla_rta(r: *const ifinfomsg) -> *const rtattr { as *const rtattr; } -pub struct InterfaceIterator { +pub struct Socket { socket: libc::c_int, src_addr: libc::sockaddr_nl, - dest_addr: libc::sockaddr_nl, - subscribed: bool, + dest_addr: libc::sockaddr_nl +} + +pub fn create_netlink_socket(subscribed: bool) -> Result { + let mut src_addr = unsafe { std::mem::zeroed::() }; + let mut dest_addr = unsafe { std::mem::zeroed::() }; + + src_addr.nl_family = libc::AF_NETLINK as u16; + if subscribed { + src_addr.nl_groups = libc::RTMGRP_LINK as u32; + } + src_addr.nl_pid = unsafe { libc::getpid() } as u32; + + dest_addr.nl_family = libc::AF_NETLINK as u16; + + let socket; + unsafe { + socket = libc::socket(libc::AF_NETLINK, libc::SOCK_RAW, libc::NETLINK_ROUTE); + if socket == -1 { + return Err("Failed to create socket"); + } + + if libc::bind(socket, &src_addr as *const _ as *const libc::sockaddr, std::mem::size_of::() as u32) == -1 { + return Err("Failed to bind to socket"); + } + } + + Ok(Socket { + socket: socket, + src_addr: src_addr, + dest_addr: dest_addr + }) +} + +impl Socket { + pub fn send(&mut self, msg: &[u8]) -> Result<(), i32> { + let mut iov = libc::iovec { + iov_base: msg as *const _ as *mut libc::c_void, + iov_len: msg.len() as usize, + }; + + let mut msg = unsafe { std::mem::zeroed::() }; + msg.msg_name = &mut self.dest_addr as *mut _ as *mut libc::c_void; + msg.msg_namelen = std::mem::size_of::() as u32; + msg.msg_iov = &mut iov as *mut libc::iovec; + msg.msg_iovlen = 1; + + if unsafe { libc::sendmsg(self.socket, &msg as *const libc::msghdr, 0) } < 0 { + return Err(unsafe { *libc::__errno_location() }); + } + + Ok(()) + } + + pub fn recv(&mut self, buffer: &mut [u8; N]) -> Result { + let mut iov = libc::iovec { + iov_base: buffer as *mut _ as *mut libc::c_void, + iov_len: N, + }; + + let mut msg = unsafe { std::mem::zeroed::() }; + msg.msg_name = &mut self.src_addr as *mut _ as *mut libc::c_void; + msg.msg_namelen = std::mem::size_of::() as u32; + msg.msg_iov = &mut iov as *mut libc::iovec; + msg.msg_iovlen = 1; + + let n = unsafe { libc::recvmsg(self.socket, &mut msg as *mut libc::msghdr, 0) }; + if n < 0 { + return Err(n as i32); + } else { + return Ok(n); + } + } +} + +pub struct InterfaceIterator { + socket: Socket, buffer: [u8; 4096], current_header: Option<*const libc::nlmsghdr>, current_len: u32, } -pub fn new_interface_iterator() -> Result { - let socket; +pub unsafe fn to_slice(p: &T) -> &[u8] { + std::slice::from_raw_parts( + (p as *const T) as *const u8, + std::mem::size_of::(), + ) +} - let mut src_addr = unsafe { std::mem::zeroed::() }; - let mut dest_addr = unsafe { std::mem::zeroed::() }; - - src_addr.nl_family = libc::AF_NETLINK as u16; - src_addr.nl_groups = libc::RTMGRP_LINK as u32; - src_addr.nl_pid = unsafe { libc::getpid() } as u32; - - dest_addr.nl_family = libc::AF_NETLINK as u16; - - unsafe { - socket = libc::socket(libc::AF_NETLINK, libc::SOCK_RAW, libc::NETLINK_ROUTE); - if socket == -1 { - return Err("Failed to create AF_NETLINK socket!".to_string()); - } - - if libc::bind( - socket, - &src_addr as *const libc::sockaddr_nl as *const libc::sockaddr, - std::mem::size_of::() as u32, - ) == -1 - { - return Err("Failed to bind to AF_NETLINK socket!".to_string()); - } - } +pub fn new_interface_iterator() -> Result { + let mut socket = create_netlink_socket(true).unwrap(); let mut request = unsafe { std::mem::zeroed::<(libc::nlmsghdr, ifinfomsg)>() }; - request.0.nlmsg_len = nlmsg_length!(std::mem::size_of::()); + request.0.nlmsg_len = nlmsg_length(std::mem::size_of::() as u32); request.0.nlmsg_type = libc::RTM_GETLINK; request.0.nlmsg_flags = (libc::NLM_F_REQUEST | libc::NLM_F_DUMP) as u16; request.1.ifi_family = libc::AF_NETLINK as u8; - let mut iov = libc::iovec { - iov_base: &mut request as *mut _ as *mut libc::c_void, - iov_len: request.0.nlmsg_len as usize, - }; - - let mut msg = unsafe { std::mem::zeroed::() }; - msg.msg_name = &mut dest_addr as *mut _ as *mut libc::c_void; - msg.msg_namelen = std::mem::size_of::() as u32; - msg.msg_iov = &mut iov as *mut libc::iovec; - msg.msg_iovlen = 1; - - if unsafe { libc::sendmsg(socket, &msg as *const libc::msghdr, 0) } < 0 { - return Err("Failed to send interface request!".to_string()); - } + socket.send(unsafe {to_slice(&request)}).expect("Failed to send interface request!"); Ok(InterfaceIterator { socket: socket, - src_addr: src_addr, - dest_addr: dest_addr, - subscribed: true, buffer: [0; 4096], current_header: None, @@ -213,18 +281,7 @@ impl Iterator for InterfaceIterator { fn next(&mut self) -> Option { loop { if let None = self.current_header { - let mut recv_iov = libc::iovec { - iov_base: &mut self.buffer as *mut _ as *mut libc::c_void, - iov_len: 4096, - }; - - let mut msg = unsafe { std::mem::zeroed::() }; - msg.msg_name = &mut self.dest_addr as *mut _ as *mut libc::c_void; - msg.msg_namelen = std::mem::size_of::() as u32; - msg.msg_iov = &mut recv_iov as *mut libc::iovec; - msg.msg_iovlen = 1; - - let n = unsafe { libc::recvmsg(self.socket, &mut msg as *mut libc::msghdr, 0) }; + let n = self.socket.recv(&mut self.buffer).unwrap(); if n < 0 { panic!("Failed to receive on AF_NETLINK socket!"); }