diff --git a/Apple/NetworkExtension/PacketTunnelProvider.swift b/Apple/NetworkExtension/PacketTunnelProvider.swift index 8260aa0..e9c48dd 100644 --- a/Apple/NetworkExtension/PacketTunnelProvider.swift +++ b/Apple/NetworkExtension/PacketTunnelProvider.swift @@ -6,7 +6,7 @@ class PacketTunnelProvider: NEPacketTunnelProvider { let logger = Logger(subsystem: "com.hackclub.burrow", category: "frontend") var client: BurrowIpc? var osInitialized = false - override func startTunnel(options: [String: NSObject]?, completionHandler: @escaping (Error?) -> Void) { + override func startTunnel(options: [String : NSObject]? = nil) async throws { logger.log("Starting tunnel") if !osInitialized { libburrow.initialize_oslog() @@ -15,38 +15,35 @@ class PacketTunnelProvider: NEPacketTunnelProvider { libburrow.start_srv() client = BurrowIpc(logger: logger) logger.info("Started server") - Task { - do { - let command = BurrowSingleCommand(id: 0, command: "ServerConfig") - guard let data = try await client?.request(command, type: Response>.self) - else { - throw BurrowError.cantParseResult - } - let encoded = try JSONEncoder().encode(data.result) - self.logger.log("Received final data: \(String(decoding: encoded, as: UTF8.self))") - guard let serverconfig = data.result.Ok else { - throw BurrowError.resultIsError - } - guard let tunNs = self.generateTunSettings(from: serverconfig) else { - throw BurrowError.addrDoesntExist - } - try await self.setTunnelNetworkSettings(tunNs) - self.logger.info("Set remote tunnel address to \(tunNs.tunnelRemoteAddress)") - -// let tunFd = self.packetFlow.value(forKeyPath: "socket.fileDescriptor") as! Int; -// self.logger.info("Found File Descriptor: \(tunFd)") - let start_command = start_req_fd(id: 1, fd: 0) - guard let data = try await client?.request(start_command, type: Response>.self) - else { - throw BurrowError.cantParseResult - } - let encoded_startres = try JSONEncoder().encode(data.result) - self.logger.log("Received start server response: \(String(decoding: encoded_startres, as: UTF8.self))") - completionHandler(nil) - } catch { - self.logger.error("An error occurred: \(error)") - completionHandler(error) + do { + let command = BurrowSingleCommand(id: 0, command: "ServerConfig") + guard let data = try await client?.request(command, type: Response>.self) + else { + throw BurrowError.cantParseResult } + let encoded = try JSONEncoder().encode(data.result) + self.logger.log("Received final data: \(String(decoding: encoded, as: UTF8.self))") + guard let serverconfig = data.result.Ok else { + throw BurrowError.resultIsError + } + guard let tunNs = self.generateTunSettings(from: serverconfig) else { + throw BurrowError.addrDoesntExist + } + try await self.setTunnelNetworkSettings(tunNs) + self.logger.info("Set remote tunnel address to \(tunNs.tunnelRemoteAddress)") + + // let tunFd = self.packetFlow.value(forKeyPath: "socket.fileDescriptor") as! Int; + // self.logger.info("Found File Descriptor: \(tunFd)") + let start_command = start_req_fd(id: 1, fd: 0) + guard let data = try await client?.request(start_command, type: Response>.self) + else { + throw BurrowError.cantParseResult + } + let encoded_startres = try JSONEncoder().encode(data.result) + self.logger.log("Received start server response: \(String(decoding: encoded_startres, as: UTF8.self))") + } catch { + self.logger.error("An error occurred: \(error)") + throw error } } private func generateTunSettings(from: ServerConfigData) -> NETunnelNetworkSettings? { @@ -60,17 +57,16 @@ class PacketTunnelProvider: NEPacketTunnelProvider { logger.log("Initialized ipv4 settings: \(nst.ipv4Settings)") return nst } - override func stopTunnel(with reason: NEProviderStopReason, completionHandler: @escaping () -> Void) { - completionHandler() + override func stopTunnel(with reason: NEProviderStopReason) async { + } - override func handleAppMessage(_ messageData: Data, completionHandler: ((Data?) -> Void)?) { - if let handler = completionHandler { - handler(messageData) - } + override func handleAppMessage(_ messageData: Data) async -> Data? { + messageData } - override func sleep(completionHandler: @escaping () -> Void) { - completionHandler() + override func sleep() async { + } override func wake() { + } } diff --git a/burrow/src/apple.rs b/burrow/src/apple.rs index dd50fc2..571b413 100644 --- a/burrow/src/apple.rs +++ b/burrow/src/apple.rs @@ -1,8 +1,6 @@ -use tracing::instrument::WithSubscriber; -use tracing::{debug, Subscriber}; +use tracing::debug; use tracing_oslog::OsLogger; use tracing_subscriber::layer::SubscriberExt; -use tracing_subscriber::FmtSubscriber; pub use crate::daemon::start_srv; diff --git a/burrow/src/daemon/command.rs b/burrow/src/daemon/command.rs index 776e172..53b4108 100644 --- a/burrow/src/daemon/command.rs +++ b/burrow/src/daemon/command.rs @@ -23,10 +23,7 @@ fn test_daemoncommand_serialization() { .unwrap()); insta::assert_snapshot!( serde_json::to_string(&DaemonCommand::Start(DaemonStartOptions { - tun: TunOptions { - seek_utun: true, - ..TunOptions::default() - } + tun: TunOptions { ..TunOptions::default() } })) .unwrap() ); diff --git a/burrow/src/daemon/instance.rs b/burrow/src/daemon/instance.rs index c79da05..98052d2 100644 --- a/burrow/src/daemon/instance.rs +++ b/burrow/src/daemon/instance.rs @@ -1,10 +1,17 @@ -use super::*; -use crate::daemon::response::{DaemonResponseData, ServerConfig, ServerInfo}; -use tokio::task::JoinHandle; -use tracing::field::debug; +use std::sync::Arc; + +use anyhow::Result; +use tokio::{sync::RwLock, task::JoinHandle}; use tracing::{debug, info, warn}; use tun::tokio::TunInterface; -use DaemonResponse; + +use crate::{ + daemon::{ + command::DaemonCommand, + response::{DaemonResponse, DaemonResponseData, ServerConfig, ServerInfo}, + }, + wireguard::Interface, +}; enum RunState { Running(JoinHandle>), @@ -48,7 +55,9 @@ impl DaemonInstance { } RunState::Idle => { debug!("Creating new TunInterface"); - let tun_if = Arc::new(RwLock::new(TunInterface::new(st.tun.open()?)?)); + let retrieved = TunInterface::new(tun::TunInterface::retrieve().unwrap())?; + let tun_if = Arc::new(RwLock::new(retrieved)); + // let tun_if = Arc::new(RwLock::new(TunInterface::new(st.tun.open()?)?)); debug!("TunInterface created: {:?}", tun_if); debug!("Setting tun_interface"); diff --git a/burrow/src/daemon/mod.rs b/burrow/src/daemon/mod.rs index 1aa6ea4..394ebec 100644 --- a/burrow/src/daemon/mod.rs +++ b/burrow/src/daemon/mod.rs @@ -1,26 +1,29 @@ -use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4, ToSocketAddrs}; -use std::sync::Arc; +use std::{ + net::{IpAddr, Ipv4Addr, SocketAddr}, + sync::Arc, +}; mod command; mod instance; mod net; mod response; -use crate::wireguard::{Interface, Peer, PublicKey, StaticSecret}; use anyhow::{Error, Result}; use base64::{engine::general_purpose, Engine as _}; pub use command::{DaemonCommand, DaemonStartOptions}; use fehler::throws; use instance::DaemonInstance; use ip_network::{IpNetwork, Ipv4Network}; -pub use net::DaemonClient; -use tokio::sync::RwLock; - #[cfg(target_vendor = "apple")] pub use net::start_srv; - -use crate::daemon::net::listen; +pub use net::DaemonClient; pub use response::{DaemonResponse, DaemonResponseData, ServerInfo}; +use tokio::sync::RwLock; + +use crate::{ + daemon::net::listen, + wireguard::{Interface, Peer, PublicKey, StaticSecret}, +}; #[throws] fn parse_key(string: &str) -> [u8; 32] { diff --git a/burrow/src/daemon/net/apple.rs b/burrow/src/daemon/net/apple.rs index 1242dfe..b84ec08 100644 --- a/burrow/src/daemon/net/apple.rs +++ b/burrow/src/daemon/net/apple.rs @@ -1,9 +1,10 @@ -use crate::daemon::{daemon_main, DaemonClient}; -use std::future::Future; use std::thread; + use tokio::runtime::Runtime; use tracing::{error, info}; +use crate::daemon::{daemon_main, DaemonClient}; + #[no_mangle] pub extern "C" fn start_srv() { info!("Rust: Starting server"); @@ -20,7 +21,7 @@ pub extern "C" fn start_srv() { loop { match DaemonClient::new().await { Ok(_) => break, - Err(e) => { + Err(_e) => { // error!("Error when connecting to daemon: {}", e) } } diff --git a/burrow/src/daemon/net/unix.rs b/burrow/src/daemon/net/unix.rs index d0e5b26..7ce8992 100644 --- a/burrow/src/daemon/net/unix.rs +++ b/burrow/src/daemon/net/unix.rs @@ -1,25 +1,21 @@ -use super::*; -use anyhow::anyhow; -use log::log; -use std::hash::Hash; -use std::path::PathBuf; use std::{ - ascii, io, + io, os::{ fd::{FromRawFd, RawFd}, unix::net::UnixListener as StdUnixListener, }, - path::Path, + path::{Path, PathBuf}, }; -use tracing::info; -use crate::daemon::{DaemonCommand, DaemonResponse, DaemonResponseData}; -use anyhow::Result; +use anyhow::{anyhow, Result}; use tokio::{ io::{AsyncBufReadExt, AsyncWriteExt, BufReader}, net::{UnixListener, UnixStream}, }; -use tracing::debug; +use tracing::{debug, info}; + +use super::*; +use crate::daemon::{DaemonCommand, DaemonResponse, DaemonResponseData}; #[cfg(not(target_vendor = "apple"))] const UNIX_SOCKET_PATH: &str = "/run/burrow.sock"; @@ -40,7 +36,7 @@ fn fetch_socket_path() -> Option { for path in tries { let path = PathBuf::from(path); if path.exists() { - return Some(path); + return Some(path) } } None diff --git a/burrow/src/daemon/response.rs b/burrow/src/daemon/response.rs index 63d10e8..4bebe14 100644 --- a/burrow/src/daemon/response.rs +++ b/burrow/src/daemon/response.rs @@ -1,4 +1,3 @@ -use anyhow::anyhow; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use tun::TunInterface; diff --git a/burrow/src/ensureroot.rs b/burrow/src/ensureroot.rs deleted file mode 100644 index b7c0757..0000000 --- a/burrow/src/ensureroot.rs +++ /dev/null @@ -1,40 +0,0 @@ -use tracing::instrument; - -// Check capabilities on Linux -#[cfg(target_os = "linux")] -#[instrument] -pub fn ensure_root() { - use caps::{has_cap, CapSet, Capability}; - - let cap_net_admin = Capability::CAP_NET_ADMIN; - if let Ok(has_cap) = has_cap(None, CapSet::Effective, cap_net_admin) { - if !has_cap { - eprintln!( - "This action needs the CAP_NET_ADMIN permission. Did you mean to run it as root?" - ); - std::process::exit(77); - } - } else { - eprintln!("Failed to check capabilities. Please file a bug report!"); - std::process::exit(71); - } -} - -// Check for root user on macOS -#[cfg(target_vendor = "apple")] -#[instrument] -pub fn ensure_root() { - use nix::unistd::Uid; - - let current_uid = Uid::current(); - if !current_uid.is_root() { - eprintln!("This action must be run as root!"); - std::process::exit(77); - } -} - -#[cfg(target_family = "windows")] -#[instrument] -pub fn ensure_root() { - todo!() -} diff --git a/burrow/src/lib.rs b/burrow/src/lib.rs index 07cb2f6..9df60f0 100644 --- a/burrow/src/lib.rs +++ b/burrow/src/lib.rs @@ -1,21 +1,12 @@ -pub mod ensureroot; pub mod wireguard; -use anyhow::Result; - -#[cfg(any(target_os = "linux", target_vendor = "apple"))] -use std::{ - mem, - os::fd::{AsRawFd, FromRawFd}, -}; - -use tun::TunInterface; - -// TODO Separate start and retrieve functions - mod daemon; pub use daemon::{ - DaemonCommand, DaemonResponse, DaemonResponseData, DaemonStartOptions, ServerInfo, + DaemonCommand, + DaemonResponse, + DaemonResponseData, + DaemonStartOptions, + ServerInfo, }; #[cfg(target_vendor = "apple")] diff --git a/burrow/src/main.rs b/burrow/src/main.rs index ff0ed53..5003277 100644 --- a/burrow/src/main.rs +++ b/burrow/src/main.rs @@ -1,24 +1,21 @@ -use std::mem; -#[cfg(any(target_os = "linux", target_vendor = "apple"))] -use std::os::fd::FromRawFd; - use anyhow::{Context, Result}; #[cfg(any(target_os = "linux", target_vendor = "apple"))] use clap::{Args, Parser, Subcommand}; -use tracing::{instrument, Level}; +use tracing::instrument; use tracing_log::LogTracer; use tracing_oslog::OsLogger; use tracing_subscriber::{prelude::*, EnvFilter, FmtSubscriber}; #[cfg(any(target_os = "linux", target_vendor = "apple"))] -use tun::{retrieve, TunInterface}; +use tun::TunInterface; mod daemon; mod wireguard; -use crate::daemon::DaemonResponseData; use daemon::{DaemonClient, DaemonCommand, DaemonStartOptions}; use tun::TunOptions; +use crate::daemon::DaemonResponseData; + #[derive(Parser)] #[command(name = "Burrow")] #[command(author = "Hack Club ")] @@ -88,8 +85,7 @@ async fn try_retrieve() -> Result<()> { } } - burrow::ensureroot::ensure_root(); - let iface2 = retrieve().ok_or(anyhow::anyhow!("No interface found"))?; + let iface2 = TunInterface::retrieve().ok_or(anyhow::anyhow!("No interface found"))?; tracing::info!("{:?}", iface2); Ok(()) } diff --git a/burrow/src/wireguard/iface.rs b/burrow/src/wireguard/iface.rs index 52f719b..9a0c216 100755 --- a/burrow/src/wireguard/iface.rs +++ b/burrow/src/wireguard/iface.rs @@ -1,24 +1,15 @@ -use std::sync::Arc; -use std::time::Duration; -use std::{net::IpAddr, rc::Rc}; +use std::{net::IpAddr, sync::Arc, time::Duration}; use anyhow::Error; use async_trait::async_trait; use fehler::throws; -use futures::future::join_all; -use futures::FutureExt; +use futures::{future::join_all, FutureExt}; use ip_network_table::IpNetworkTable; -use log::log; -use tokio::time::timeout; -use tokio::{ - join, - sync::{Mutex, RwLock}, - task::{self, JoinHandle}, -}; +use tokio::{sync::RwLock, task::JoinHandle, time::timeout}; use tracing::{debug, error}; use tun::tokio::TunInterface; -use super::{noise::Tunnel, pcb, Peer, PeerPcb}; +use super::{noise::Tunnel, Peer, PeerPcb}; #[async_trait] pub trait PacketInterface { @@ -122,12 +113,9 @@ impl Interface { Ok(Ok(len)) => &buf[..len], Ok(Err(e)) => { error!("failed to read from interface: {}", e); - continue; - } - Err(_would_block) => { - debug!("read timed out"); - continue; + continue } + Err(_would_block) => continue, }; debug!("read {} bytes from interface", src.len()); debug!("bytes: {:?}", src); @@ -138,7 +126,7 @@ impl Interface { Some(addr) => addr, None => { tracing::debug!("no destination found"); - continue; + continue } }; @@ -156,7 +144,7 @@ impl Interface { } Err(e) => { log::error!("failed to send packet {}", e); - continue; + continue } }; } @@ -175,20 +163,20 @@ impl Interface { let pcbs = &self.pcbs; for i in 0..pcbs.pcbs.len() { debug!("spawning read task for peer {}", i); - let mut pcb = pcbs.pcbs[i].clone(); + let pcb = pcbs.pcbs[i].clone(); let tun = tun.clone(); let tsk = async move { { let r1 = pcb.write().await.open_if_closed().await; if let Err(e) = r1 { log::error!("failed to open pcb: {}", e); - return; + return } } let r2 = pcb.read().await.run(tun).await; if let Err(e) = r2 { log::error!("failed to run pcb: {}", e); - return; + return } else { tracing::debug!("pcb ran successfully"); } diff --git a/burrow/src/wireguard/noise/handshake.rs b/burrow/src/wireguard/noise/handshake.rs index 3f8c91b..2ec0c6a 100755 --- a/burrow/src/wireguard/noise/handshake.rs +++ b/burrow/src/wireguard/noise/handshake.rs @@ -9,14 +9,20 @@ use std::{ use aead::{Aead, Payload}; use blake2::{ digest::{FixedOutput, KeyInit}, - Blake2s256, Blake2sMac, Digest, + Blake2s256, + Blake2sMac, + Digest, }; use chacha20poly1305::XChaCha20Poly1305; use rand_core::OsRng; use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305}; use super::{ - errors::WireGuardError, session::Session, x25519, HandshakeInit, HandshakeResponse, + errors::WireGuardError, + session::Session, + x25519, + HandshakeInit, + HandshakeResponse, PacketCookieReply, }; @@ -130,10 +136,6 @@ fn aead_chacha20_open( ) -> Result<(), WireGuardError> { let mut nonce: [u8; 12] = [0; 12]; nonce[4..].copy_from_slice(&counter.to_le_bytes()); - - tracing::debug!("TAG A"); - tracing::debug!("{:?};{:?};{:?};{:?};{}", key, data, aad, nonce, counter); - aead_chacha20_open_inner(buffer, key, nonce, data, aad) .map_err(|_| WireGuardError::InvalidAeadTag)?; Ok(()) @@ -207,7 +209,7 @@ impl Tai64N { /// Parse a timestamp from a 12 byte u8 slice fn parse(buf: &[u8; 12]) -> Result { if buf.len() < 12 { - return Err(WireGuardError::InvalidTai64nTimestamp); + return Err(WireGuardError::InvalidTai64nTimestamp) } let (sec_bytes, nano_bytes) = buf.split_at(std::mem::size_of::()); @@ -554,22 +556,19 @@ impl Handshake { let timestamp = Tai64N::parse(×tamp)?; if !timestamp.after(&self.last_handshake_timestamp) { // Possibly a replay - return Err(WireGuardError::WrongTai64nTimestamp); + return Err(WireGuardError::WrongTai64nTimestamp) } self.last_handshake_timestamp = timestamp; // initiator.hash = HASH(initiator.hash || msg.encrypted_timestamp) hash = b2s_hash(&hash, packet.encrypted_timestamp); - self.previous = std::mem::replace( - &mut self.state, - HandshakeState::InitReceived { - chaining_key, - hash, - peer_ephemeral_public, - peer_index, - }, - ); + self.previous = std::mem::replace(&mut self.state, HandshakeState::InitReceived { + chaining_key, + hash, + peer_ephemeral_public, + peer_index, + }); self.format_handshake_response(dst) } @@ -670,7 +669,7 @@ impl Handshake { let local_index = self.cookies.index; if packet.receiver_idx != local_index { - return Err(WireGuardError::WrongIndex); + return Err(WireGuardError::WrongIndex) } // msg.encrypted_cookie = XAEAD(HASH(LABEL_COOKIE || responder.static_public), // msg.nonce, cookie, last_received_msg.mac1) @@ -680,7 +679,6 @@ impl Handshake { aad: &mac1[0..16], msg: packet.encrypted_cookie, }; - tracing::debug!("TAG B"); let plaintext = XChaCha20Poly1305::new_from_slice(&key) .unwrap() .decrypt(packet.nonce.into(), payload) @@ -727,7 +725,7 @@ impl Handshake { dst: &'a mut [u8], ) -> Result<&'a mut [u8], WireGuardError> { if dst.len() < super::HANDSHAKE_INIT_SZ { - return Err(WireGuardError::DestinationBufferTooSmall); + return Err(WireGuardError::DestinationBufferTooSmall) } let (message_type, rest) = dst.split_at_mut(4); @@ -810,7 +808,7 @@ impl Handshake { dst: &'a mut [u8], ) -> Result<(&'a mut [u8], Session), WireGuardError> { if dst.len() < super::HANDSHAKE_RESP_SZ { - return Err(WireGuardError::DestinationBufferTooSmall); + return Err(WireGuardError::DestinationBufferTooSmall) } let state = std::mem::replace(&mut self.state, HandshakeState::None); diff --git a/burrow/src/wireguard/noise/mod.rs b/burrow/src/wireguard/noise/mod.rs index 824d7c1..3a60c22 100755 --- a/burrow/src/wireguard/noise/mod.rs +++ b/burrow/src/wireguard/noise/mod.rs @@ -45,7 +45,11 @@ const N_SESSIONS: usize = 8; pub mod x25519 { pub use x25519_dalek::{ - EphemeralSecret, PublicKey, ReusableSecret, SharedSecret, StaticSecret, + EphemeralSecret, + PublicKey, + ReusableSecret, + SharedSecret, + StaticSecret, }; } @@ -137,7 +141,7 @@ impl Tunnel { #[inline(always)] pub fn parse_incoming_packet(src: &[u8]) -> Result { if src.len() < 4 { - return Err(WireGuardError::InvalidPacket); + return Err(WireGuardError::InvalidPacket) } // Checks the type, as well as the reserved zero fields @@ -179,7 +183,7 @@ impl Tunnel { pub fn dst_address(packet: &[u8]) -> Option { if packet.is_empty() { - return None; + return None } match packet[0] >> 4 { @@ -274,7 +278,7 @@ impl Tunnel { self.timer_tick(TimerName::TimeLastDataPacketSent); } self.tx_bytes += src.len(); - return TunnResult::WriteToNetwork(packet); + return TunnResult::WriteToNetwork(packet) } // If there is no session, queue the packet for future retry @@ -298,7 +302,7 @@ impl Tunnel { ) -> TunnResult<'a> { if datagram.is_empty() { // Indicates a repeated call - return self.send_queued_packet(dst); + return self.send_queued_packet(dst) } let mut cookie = [0u8; COOKIE_REPLY_SZ]; @@ -309,7 +313,7 @@ impl Tunnel { Ok(packet) => packet, Err(TunnResult::WriteToNetwork(cookie)) => { dst[..cookie.len()].copy_from_slice(cookie); - return TunnResult::WriteToNetwork(&mut dst[..cookie.len()]); + return TunnResult::WriteToNetwork(&mut dst[..cookie.len()]) } Err(TunnResult::Err(e)) => return TunnResult::Err(e), _ => unreachable!(), @@ -409,7 +413,7 @@ impl Tunnel { let cur_idx = self.current; if cur_idx == new_idx { // There is nothing to do, already using this session, this is the common case - return; + return } if self.sessions[cur_idx % N_SESSIONS].is_none() || self.timers.session_timers[new_idx % N_SESSIONS] @@ -455,7 +459,7 @@ impl Tunnel { force_resend: bool, ) -> TunnResult<'a> { if self.handshake.is_in_progress() && !force_resend { - return TunnResult::Done; + return TunnResult::Done } if self.handshake.is_expired() { @@ -514,7 +518,7 @@ impl Tunnel { }; if computed_len > packet.len() { - return TunnResult::Err(WireGuardError::InvalidPacket); + return TunnResult::Err(WireGuardError::InvalidPacket) } self.timer_tick(TimerName::TimeLastDataPacketReceived); diff --git a/burrow/src/wireguard/noise/rate_limiter.rs b/burrow/src/wireguard/noise/rate_limiter.rs index 02887ee..ff19efd 100755 --- a/burrow/src/wireguard/noise/rate_limiter.rs +++ b/burrow/src/wireguard/noise/rate_limiter.rs @@ -6,16 +6,25 @@ use std::{ use aead::{generic_array::GenericArray, AeadInPlace, KeyInit}; use chacha20poly1305::{Key, XChaCha20Poly1305}; -use log::log; use parking_lot::Mutex; use rand_core::{OsRng, RngCore}; use ring::constant_time::verify_slices_are_equal; use super::{ handshake::{ - b2s_hash, b2s_keyed_mac_16, b2s_keyed_mac_16_2, b2s_mac_24, LABEL_COOKIE, LABEL_MAC1, + b2s_hash, + b2s_keyed_mac_16, + b2s_keyed_mac_16_2, + b2s_mac_24, + LABEL_COOKIE, + LABEL_MAC1, }, - HandshakeInit, HandshakeResponse, Packet, TunnResult, Tunnel, WireGuardError, + HandshakeInit, + HandshakeResponse, + Packet, + TunnResult, + Tunnel, + WireGuardError, }; const COOKIE_REFRESH: u64 = 128; // Use 128 and not 120 so the compiler can optimize out the division @@ -127,7 +136,7 @@ impl RateLimiter { dst: &'a mut [u8], ) -> Result<&'a mut [u8], WireGuardError> { if dst.len() < super::COOKIE_REPLY_SZ { - return Err(WireGuardError::DestinationBufferTooSmall); + return Err(WireGuardError::DestinationBufferTooSmall) } let (message_type, rest) = dst.split_at_mut(4); @@ -193,7 +202,7 @@ impl RateLimiter { let cookie_packet = self .format_cookie_reply(sender_idx, cookie, mac1, dst) .map_err(TunnResult::Err)?; - return Err(TunnResult::WriteToNetwork(cookie_packet)); + return Err(TunnResult::WriteToNetwork(cookie_packet)) } } } diff --git a/burrow/src/wireguard/noise/session.rs b/burrow/src/wireguard/noise/session.rs index 14c191b..8988728 100755 --- a/burrow/src/wireguard/noise/session.rs +++ b/burrow/src/wireguard/noise/session.rs @@ -88,11 +88,11 @@ impl ReceivingKeyCounterValidator { fn will_accept(&self, counter: u64) -> Result<(), WireGuardError> { if counter >= self.next { // As long as the counter is growing no replay took place for sure - return Ok(()); + return Ok(()) } if counter + N_BITS < self.next { // Drop if too far back - return Err(WireGuardError::InvalidCounter); + return Err(WireGuardError::InvalidCounter) } if !self.check_bit(counter) { Ok(()) @@ -107,22 +107,22 @@ impl ReceivingKeyCounterValidator { fn mark_did_receive(&mut self, counter: u64) -> Result<(), WireGuardError> { if counter + N_BITS < self.next { // Drop if too far back - return Err(WireGuardError::InvalidCounter); + return Err(WireGuardError::InvalidCounter) } if counter == self.next { // Usually the packets arrive in order, in that case we simply mark the bit and // increment the counter self.set_bit(counter); self.next += 1; - return Ok(()); + return Ok(()) } if counter < self.next { // A packet arrived out of order, check if it is valid, and mark if self.check_bit(counter) { - return Err(WireGuardError::InvalidCounter); + return Err(WireGuardError::InvalidCounter) } self.set_bit(counter); - return Ok(()); + return Ok(()) } // Packets where dropped, or maybe reordered, skip them and mark unused if counter - self.next >= N_BITS { @@ -247,7 +247,7 @@ impl Session { panic!("The destination buffer is too small"); } if packet.receiver_idx != self.receiving_index { - return Err(WireGuardError::WrongIndex); + return Err(WireGuardError::WrongIndex) } // Don't reuse counters, in case this is a replay attack we want to quickly // check the counter without running expensive decryption diff --git a/burrow/src/wireguard/noise/timers.rs b/burrow/src/wireguard/noise/timers.rs index f713e6f..1d0cf1f 100755 --- a/burrow/src/wireguard/noise/timers.rs +++ b/burrow/src/wireguard/noise/timers.rs @@ -190,7 +190,7 @@ impl Tunnel { { if self.handshake.is_expired() { - return TunnResult::Err(WireGuardError::ConnectionExpired); + return TunnResult::Err(WireGuardError::ConnectionExpired) } // Clear cookie after COOKIE_EXPIRATION_TIME @@ -206,7 +206,7 @@ impl Tunnel { tracing::error!("CONNECTION_EXPIRED(REJECT_AFTER_TIME * 3)"); self.handshake.set_expired(); self.clear_all(); - return TunnResult::Err(WireGuardError::ConnectionExpired); + return TunnResult::Err(WireGuardError::ConnectionExpired) } if let Some(time_init_sent) = self.handshake.timer() { @@ -219,7 +219,7 @@ impl Tunnel { tracing::error!("CONNECTION_EXPIRED(REKEY_ATTEMPT_TIME)"); self.handshake.set_expired(); self.clear_all(); - return TunnResult::Err(WireGuardError::ConnectionExpired); + return TunnResult::Err(WireGuardError::ConnectionExpired) } if time_init_sent.elapsed() >= REKEY_TIMEOUT { @@ -299,11 +299,11 @@ impl Tunnel { } if handshake_initiation_required { - return self.format_handshake_initiation(dst, true); + return self.format_handshake_initiation(dst, true) } if keepalive_required { - return self.encapsulate(&[], dst); + return self.encapsulate(&[], dst) } TunnResult::Done diff --git a/burrow/src/wireguard/pcb.rs b/burrow/src/wireguard/pcb.rs index 6acc8d8..d11e736 100755 --- a/burrow/src/wireguard/pcb.rs +++ b/burrow/src/wireguard/pcb.rs @@ -1,19 +1,11 @@ -use std::io; -use std::net::SocketAddr; -use std::rc::Rc; -use std::sync::Arc; -use std::time::Duration; +use std::{net::SocketAddr, sync::Arc}; use anyhow::{anyhow, Error}; use fehler::throws; use ip_network::IpNetwork; -use log::log; use rand::random; -use tokio::sync::{Mutex, RwLock}; -use tokio::time::timeout; -use tokio::{net::UdpSocket, task::JoinHandle}; +use tokio::{net::UdpSocket, sync::RwLock, task::JoinHandle}; use tun::tokio::TunInterface; -use uuid::uuid; use super::{ iface::PacketInterface, @@ -83,16 +75,14 @@ impl PeerPcb { tracing::debug!("start read loop {}", rid); loop { tracing::debug!("{}: waiting for packet", rid); - let Some(socket) = &self.socket else { - continue - }; + let Some(socket) = &self.socket else { continue }; let mut res_buf = [0; 1500]; // tracing::debug!("{} : waiting for readability on {:?}", rid, socket); let len = match socket.recv(&mut res_buf).await { Ok(l) => l, Err(e) => { log::error!("{}: error reading from socket: {:?}", rid, e); - continue; + continue } }; let mut res_dat = &res_buf[..len]; @@ -105,33 +95,31 @@ impl PeerPcb { .await .decapsulate(None, res_dat, &mut buf[..]) { - TunnResult::Done => { - break; - } + TunnResult::Done => break, TunnResult::Err(e) => { tracing::error!(message = "Decapsulate error", error = ?e); - break; + break } TunnResult::WriteToNetwork(packet) => { tracing::debug!("WriteToNetwork: {:?}", packet); socket.send(packet).await?; tracing::debug!("WriteToNetwork done"); res_dat = &[]; - continue; + continue } TunnResult::WriteToTunnelV4(packet, addr) => { tracing::debug!("WriteToTunnelV4: {:?}, {:?}", packet, addr); tun_interface.read().await.send(packet).await?; - break; + break } TunnResult::WriteToTunnelV6(packet, addr) => { tracing::debug!("WriteToTunnelV6: {:?}, {:?}", packet, addr); tun_interface.read().await.send(packet).await?; - break; + break } } } - return Ok(len); + return Ok(len) } } diff --git a/burrow/src/wireguard/peer.rs b/burrow/src/wireguard/peer.rs index 27c5399..131b0d4 100755 --- a/burrow/src/wireguard/peer.rs +++ b/burrow/src/wireguard/peer.rs @@ -1,7 +1,5 @@ use std::{fmt, net::SocketAddr}; -use anyhow::Error; -use fehler::throws; use ip_network::IpNetwork; use x25519_dalek::{PublicKey, StaticSecret}; diff --git a/tun/src/lib.rs b/tun/src/lib.rs index 64e17df..a1ca636 100644 --- a/tun/src/lib.rs +++ b/tun/src/lib.rs @@ -15,4 +15,4 @@ mod options; pub mod tokio; pub use options::TunOptions; -pub use os_imp::{retrieve, TunInterface, TunQueue}; +pub use os_imp::{TunInterface, TunQueue}; diff --git a/tun/src/options.rs b/tun/src/options.rs index aafdad2..7c414dc 100644 --- a/tun/src/options.rs +++ b/tun/src/options.rs @@ -16,8 +16,6 @@ pub struct TunOptions { pub no_pi: bool, /// (Linux) Avoid opening an existing persistant device. pub tun_excl: bool, - /// (MacOS) Whether to seek the first available utun device. - pub seek_utun: Option, /// (Linux) The IP address of the tun interface. pub address: Option, } diff --git a/tun/src/tokio/mod.rs b/tun/src/tokio/mod.rs index c901cba..525e4d7 100644 --- a/tun/src/tokio/mod.rs +++ b/tun/src/tokio/mod.rs @@ -39,11 +39,11 @@ impl TunInterface { }) { Ok(result) => { tracing::debug!("HORRAY"); - return result; + return result } Err(_would_block) => { tracing::debug!("WouldBlock"); - continue; + continue } } } diff --git a/tun/src/unix/apple/mod.rs b/tun/src/unix/apple/mod.rs index ca3ddc7..ab08505 100644 --- a/tun/src/unix/apple/mod.rs +++ b/tun/src/unix/apple/mod.rs @@ -1,24 +1,24 @@ use std::{ io::{Error, IoSlice}, - mem, + mem::{self, ManuallyDrop}, net::{Ipv4Addr, SocketAddrV4}, - os::fd::{AsRawFd, RawFd}, + os::fd::{AsRawFd, FromRawFd, RawFd}, }; use byteorder::{ByteOrder, NetworkEndian}; use fehler::throws; use libc::{c_char, iovec, writev, AF_INET, AF_INET6}; use socket2::{Domain, SockAddr, Socket, Type}; -use tracing::{self, debug, instrument}; +use tracing::{self, instrument}; -mod kern_control; +pub mod kern_control; pub mod sys; -use crate::retrieve; use kern_control::SysControlSocket; pub use super::queue::TunQueue; -use super::{ifname_to_string, string_to_ifname, TunOptions}; +use super::{ifname_to_string, string_to_ifname}; +use crate::TunOptions; #[derive(Debug)] pub struct TunInterface { @@ -35,18 +35,41 @@ impl TunInterface { #[throws] #[instrument] pub fn new_with_options(options: TunOptions) -> TunInterface { - debug!("Opening tun interface with options: {:?}", &options); - let ti = if let Some(n) = options.seek_utun { - retrieve().ok_or(Error::new(std::io::ErrorKind::NotFound, "No utun found"))? - } else { - TunInterface::connect(0)? - }; + let ti = TunInterface::connect(0)?; + ti.configure(options)?; + ti + } + + pub fn retrieve() -> Option { + (3..100) + .filter_map(|fd| unsafe { + let peer_addr = socket2::SockAddr::init(|storage, len| { + *len = mem::size_of::() as u32; + libc::getpeername(fd, storage as *mut _, len); + Ok(()) + }) + .map(|(_, addr)| (fd, addr)); + peer_addr.ok() + }) + .filter(|(_fd, addr)| { + let ctl_addr = unsafe { &*(addr.as_ptr() as *const libc::sockaddr_ctl) }; + addr.family() == libc::AF_SYSTEM as u8 + && ctl_addr.ss_sysaddr == libc::AF_SYS_CONTROL as u16 + }) + .map(|(fd, _)| { + let socket = unsafe { socket2::Socket::from_raw_fd(fd) }; + TunInterface { socket } + }) + .next() + } + + #[throws] + fn configure(&self, options: TunOptions) { if let Some(addr) = options.address { if let Ok(addr) = addr.parse() { - ti.set_ipv4_addr(addr)?; + self.set_ipv4_addr(addr)?; } } - ti } #[throws] diff --git a/tun/src/unix/mod.rs b/tun/src/unix/mod.rs index 72a8795..775ba1d 100644 --- a/tun/src/unix/mod.rs +++ b/tun/src/unix/mod.rs @@ -1,13 +1,9 @@ -use std::mem::size_of; use std::{ io::{Error, Read}, - mem, os::fd::{AsRawFd, FromRawFd, IntoRawFd, RawFd}, }; -use tracing::{debug, error, instrument}; - -use super::{syscall, TunOptions}; +use tracing::instrument; mod queue; @@ -19,13 +15,9 @@ mod imp; #[path = "linux/mod.rs"] mod imp; -use crate::os_imp::imp::sys; -use crate::os_imp::imp::sys::resolve_ctl_info; use fehler::throws; pub use imp::TunInterface; -use libc::{getpeername, sockaddr_ctl, sockaddr_storage, socklen_t, AF_SYSTEM, AF_SYS_CONTROL}; pub use queue::TunQueue; -use socket2::SockAddr; impl AsRawFd for TunInterface { fn as_raw_fd(&self) -> RawFd { @@ -82,35 +74,3 @@ pub fn string_to_ifname(name: &str) -> [libc::c_char; libc::IFNAMSIZ] { buf[..len].copy_from_slice(unsafe { &*(name.as_bytes() as *const _ as *const [libc::c_char]) }); buf } - -#[cfg(any(target_os = "linux", target_vendor = "apple"))] -pub fn retrieve() -> Option { - (3..100) - .filter_map(|i| { - let result = unsafe { - let mut addr = sockaddr_ctl { - sc_len: size_of::() as u8, - sc_family: 0, - ss_sysaddr: 0, - sc_id: 0, - sc_unit: 0, - sc_reserved: Default::default(), - }; - let mut len = mem::size_of::() as libc::socklen_t; - let res = syscall!(getpeername(i, &mut addr as *mut _ as *mut _, len as *mut _)); - tracing::debug!("getpeername{}: {:?}", i, res); - if res.is_err() { - return None; - } - if addr.sc_family == sys::AF_SYSTEM as u8 - && addr.ss_sysaddr == sys::AF_SYS_CONTROL as u16 - { - Some(TunInterface::from_raw_fd(i)) - } else { - None - } - }; - result - }) - .next() -}