diff --git a/Apple/NetworkExtension/BurrowIpc.swift b/Apple/NetworkExtension/BurrowIpc.swift index 279cdf1..7f18679 100644 --- a/Apple/NetworkExtension/BurrowIpc.swift +++ b/Apple/NetworkExtension/BurrowIpc.swift @@ -113,7 +113,7 @@ final class BurrowIpc { return data } - func request(_ request: any Request, type: U.Type) async throws -> U { + func request(_ request: Request, type: U.Type) async throws -> U { do { var data: Data = try JSONEncoder().encode(request) data.append(contentsOf: [10]) diff --git a/Apple/NetworkExtension/DataTypes.swift b/Apple/NetworkExtension/DataTypes.swift index 6b3a070..b228d77 100644 --- a/Apple/NetworkExtension/DataTypes.swift +++ b/Apple/NetworkExtension/DataTypes.swift @@ -7,40 +7,16 @@ enum BurrowError: Error { case resultIsNone } -protocol Request: Codable where T: Codable{ - associatedtype T +protocol Request: Codable { var id: UInt { get set } - var command: T { get set } + var command: String { get set } } -struct BurrowSingleCommand: Request { +struct BurrowRequest: Request { var id: UInt var command: String } -struct BurrowRequest: Request where T: Codable{ - var id: UInt - var command: T -} - -struct BurrowStartRequest: Codable { - struct TunOptions: Codable{ - let name: String? - let no_pi: Bool - let tun_excl: Bool - let seek_utun: Int? - let address: String? - } - struct StartOptions: Codable{ - let tun: TunOptions - } - let Start: StartOptions -} - -func start_req_fd(id: UInt, fd: Int) -> BurrowRequest { - return BurrowRequest(id: id, command: BurrowStartRequest(Start: BurrowStartRequest.StartOptions(tun: BurrowStartRequest.TunOptions(name: nil, no_pi: false, tun_excl: false, seek_utun: fd, address: nil)))) -} - struct Response: Decodable where T: Decodable { var id: UInt var result: T diff --git a/Apple/NetworkExtension/PacketTunnelProvider.swift b/Apple/NetworkExtension/PacketTunnelProvider.swift index e9c48dd..4b72115 100644 --- a/Apple/NetworkExtension/PacketTunnelProvider.swift +++ b/Apple/NetworkExtension/PacketTunnelProvider.swift @@ -6,7 +6,7 @@ class PacketTunnelProvider: NEPacketTunnelProvider { let logger = Logger(subsystem: "com.hackclub.burrow", category: "frontend") var client: BurrowIpc? var osInitialized = false - override func startTunnel(options: [String : NSObject]? = nil) async throws { + override func startTunnel(options: [String: NSObject]?, completionHandler: @escaping (Error?) -> Void) { logger.log("Starting tunnel") if !osInitialized { libburrow.initialize_oslog() @@ -15,35 +15,28 @@ class PacketTunnelProvider: NEPacketTunnelProvider { libburrow.start_srv() client = BurrowIpc(logger: logger) logger.info("Started server") - do { - let command = BurrowSingleCommand(id: 0, command: "ServerConfig") - guard let data = try await client?.request(command, type: Response>.self) - else { - throw BurrowError.cantParseResult + Task { + do { + let command = BurrowRequest(id: 0, command: "ServerConfig") + guard let data = try await client?.request(command, type: Response>.self) + else { + throw BurrowError.cantParseResult + } + let encoded = try JSONEncoder().encode(data.result) + self.logger.log("Received final data: \(String(decoding: encoded, as: UTF8.self))") + guard let serverconfig = data.result.Ok else { + throw BurrowError.resultIsError + } + guard let tunNs = self.generateTunSettings(from: serverconfig) else { + throw BurrowError.addrDoesntExist + } + try await self.setTunnelNetworkSettings(tunNs) + self.logger.info("Set remote tunnel address to \(tunNs.tunnelRemoteAddress)") + completionHandler(nil) + } catch { + self.logger.error("An error occurred: \(error)") + completionHandler(error) } - let encoded = try JSONEncoder().encode(data.result) - self.logger.log("Received final data: \(String(decoding: encoded, as: UTF8.self))") - guard let serverconfig = data.result.Ok else { - throw BurrowError.resultIsError - } - guard let tunNs = self.generateTunSettings(from: serverconfig) else { - throw BurrowError.addrDoesntExist - } - try await self.setTunnelNetworkSettings(tunNs) - self.logger.info("Set remote tunnel address to \(tunNs.tunnelRemoteAddress)") - - // let tunFd = self.packetFlow.value(forKeyPath: "socket.fileDescriptor") as! Int; - // self.logger.info("Found File Descriptor: \(tunFd)") - let start_command = start_req_fd(id: 1, fd: 0) - guard let data = try await client?.request(start_command, type: Response>.self) - else { - throw BurrowError.cantParseResult - } - let encoded_startres = try JSONEncoder().encode(data.result) - self.logger.log("Received start server response: \(String(decoding: encoded_startres, as: UTF8.self))") - } catch { - self.logger.error("An error occurred: \(error)") - throw error } } private func generateTunSettings(from: ServerConfigData) -> NETunnelNetworkSettings? { @@ -57,16 +50,17 @@ class PacketTunnelProvider: NEPacketTunnelProvider { logger.log("Initialized ipv4 settings: \(nst.ipv4Settings)") return nst } - override func stopTunnel(with reason: NEProviderStopReason) async { - + override func stopTunnel(with reason: NEProviderStopReason, completionHandler: @escaping () -> Void) { + completionHandler() } - override func handleAppMessage(_ messageData: Data) async -> Data? { - messageData + override func handleAppMessage(_ messageData: Data, completionHandler: ((Data?) -> Void)?) { + if let handler = completionHandler { + handler(messageData) + } } - override func sleep() async { - + override func sleep(completionHandler: @escaping () -> Void) { + completionHandler() } override func wake() { - } } diff --git a/burrow/src/apple.rs b/burrow/src/apple.rs index 571b413..0a96877 100644 --- a/burrow/src/apple.rs +++ b/burrow/src/apple.rs @@ -1,13 +1,15 @@ -use tracing::debug; +use tracing::{debug, Subscriber}; +use tracing::instrument::WithSubscriber; use tracing_oslog::OsLogger; +use tracing_subscriber::FmtSubscriber; use tracing_subscriber::layer::SubscriberExt; pub use crate::daemon::start_srv; #[no_mangle] pub extern "C" fn initialize_oslog() { - let collector = - tracing_subscriber::registry().with(OsLogger::new("com.hackclub.burrow", "backend")); + let collector = tracing_subscriber::registry() + .with(OsLogger::new("com.hackclub.burrow", "backend")); tracing::subscriber::set_global_default(collector).unwrap(); debug!("Initialized oslog tracing in libburrow rust FFI"); -} +} \ No newline at end of file diff --git a/burrow/src/daemon/command.rs b/burrow/src/daemon/command.rs index 53b4108..a5a1f30 100644 --- a/burrow/src/daemon/command.rs +++ b/burrow/src/daemon/command.rs @@ -12,22 +12,21 @@ pub enum DaemonCommand { #[derive(Debug, Clone, Default, Serialize, Deserialize, JsonSchema)] pub struct DaemonStartOptions { - pub tun: TunOptions, + pub(super) tun: TunOptions, } #[test] fn test_daemoncommand_serialization() { - insta::assert_snapshot!(serde_json::to_string(&DaemonCommand::Start( - DaemonStartOptions::default() - )) - .unwrap()); insta::assert_snapshot!( - serde_json::to_string(&DaemonCommand::Start(DaemonStartOptions { - tun: TunOptions { ..TunOptions::default() } - })) - .unwrap() + serde_json::to_string(&DaemonCommand::Start(DaemonStartOptions::default())).unwrap() ); - insta::assert_snapshot!(serde_json::to_string(&DaemonCommand::ServerInfo).unwrap()); - insta::assert_snapshot!(serde_json::to_string(&DaemonCommand::Stop).unwrap()); - insta::assert_snapshot!(serde_json::to_string(&DaemonCommand::ServerConfig).unwrap()) -} + insta::assert_snapshot!( + serde_json::to_string(&DaemonCommand::ServerInfo).unwrap() + ); + insta::assert_snapshot!( + serde_json::to_string(&DaemonCommand::Stop).unwrap() + ); + insta::assert_snapshot!( + serde_json::to_string(&DaemonCommand::ServerConfig).unwrap() + ) +} \ No newline at end of file diff --git a/burrow/src/daemon/instance.rs b/burrow/src/daemon/instance.rs index 6a430c5..f807ba2 100644 --- a/burrow/src/daemon/instance.rs +++ b/burrow/src/daemon/instance.rs @@ -1,105 +1,50 @@ -use std::sync::Arc; - -use anyhow::Result; -use tokio::{sync::RwLock, task::JoinHandle}; use tracing::{debug, info, warn}; -use tun::tokio::TunInterface; - -use crate::{ - daemon::{ - command::DaemonCommand, - response::{DaemonResponse, DaemonResponseData, ServerConfig, ServerInfo}, - }, - wireguard::Interface, -}; - -enum RunState { - Running(JoinHandle>), - Idle, -} +use DaemonResponse; +use tun::TunInterface; +use crate::daemon::response::{DaemonResponseData, ServerConfig, ServerInfo}; +use super::*; pub struct DaemonInstance { rx: async_channel::Receiver, sx: async_channel::Sender, - tun_interface: Option>>, - wg_interface: Arc>, - wg_state: RunState, + tun_interface: Option, } impl DaemonInstance { - pub fn new( - rx: async_channel::Receiver, - sx: async_channel::Sender, - wg_interface: Arc>, - ) -> Self { + pub fn new(rx: async_channel::Receiver, sx: async_channel::Sender) -> Self { Self { rx, sx, - wg_interface, tun_interface: None, - wg_state: RunState::Idle, } } - pub fn set_tun_interface(&mut self, tun_interface: Arc>) { - self.tun_interface = Some(tun_interface); - } - async fn proc_command(&mut self, command: DaemonCommand) -> Result { info!("Daemon got command: {:?}", command); match command { DaemonCommand::Start(st) => { - match self.wg_state { - RunState::Running(_) => { - warn!("Got start, but tun interface already up."); - } - RunState::Idle => { - let raw = tun::TunInterface::retrieve().unwrap(); - debug!("TunInterface retrieved: {:?}", raw.name()?); - - let retrieved = TunInterface::new(raw)?; - let tun_if = Arc::new(RwLock::new(retrieved)); - // let tun_if = Arc::new(RwLock::new(TunInterface::new(st.tun.open()?)?)); - - debug!("Setting tun_interface"); - self.tun_interface = Some(tun_if.clone()); - debug!("tun_interface set: {:?}", self.tun_interface); - - debug!("Setting tun on wg_interface"); - self.wg_interface.write().await.set_tun(tun_if); - debug!("tun set on wg_interface"); - - debug!("Cloning wg_interface"); - let tmp_wg = self.wg_interface.clone(); - debug!("wg_interface cloned"); - - debug!("Spawning run task"); - let run_task = tokio::spawn(async move { - debug!("Running wg_interface"); - let twlock = tmp_wg.read().await; - debug!("wg_interface read lock acquired"); - twlock.run().await - }); - debug!("Run task spawned: {:?}", run_task); - - debug!("Setting wg_state to Running"); - self.wg_state = RunState::Running(run_task); - debug!("wg_state set to Running"); - - info!("Daemon started tun interface"); - } + if self.tun_interface.is_none() { + debug!("Daemon attempting start tun interface."); + self.tun_interface = Some(st.tun.open()?); + info!("Daemon started tun interface"); + } else { + warn!("Got start, but tun interface already up."); } Ok(DaemonResponseData::None) } - DaemonCommand::ServerInfo => match &self.tun_interface { - None => Ok(DaemonResponseData::None), - Some(ti) => { - info!("{:?}", ti); - Ok(DaemonResponseData::ServerInfo(ServerInfo::try_from( - ti.read().await.inner.get_ref(), - )?)) + DaemonCommand::ServerInfo => { + match &self.tun_interface { + None => {Ok(DaemonResponseData::None)} + Some(ti) => { + info!("{:?}", ti); + Ok( + DaemonResponseData::ServerInfo( + ServerInfo::try_from(ti)? + ) + ) + } } - }, + } DaemonCommand::Stop => { if self.tun_interface.is_some() { self.tun_interface = None; @@ -116,7 +61,6 @@ impl DaemonInstance { } pub async fn run(&mut self) -> Result<()> { - tracing::info!("BEGIN"); while let Ok(command) = self.rx.recv().await { let response = self.proc_command(command).await; info!("Daemon response: {:?}", response); diff --git a/burrow/src/daemon/mod.rs b/burrow/src/daemon/mod.rs index 1020cf7..e086452 100644 --- a/burrow/src/daemon/mod.rs +++ b/burrow/src/daemon/mod.rs @@ -1,7 +1,5 @@ -use std::{ - net::{IpAddr, Ipv4Addr, SocketAddr}, - sync::Arc, -}; +use std::net::{Ipv4Addr, SocketAddr, ToSocketAddrs}; + mod command; mod instance; @@ -12,18 +10,15 @@ use anyhow::{Error, Result}; use base64::{engine::general_purpose, Engine as _}; pub use command::{DaemonCommand, DaemonStartOptions}; use fehler::throws; -use instance::DaemonInstance; use ip_network::{IpNetwork, Ipv4Network}; +use instance::DaemonInstance; +use crate::wireguard::{StaticSecret, Peer, Interface, PublicKey}; +pub use net::DaemonClient; + #[cfg(target_vendor = "apple")] pub use net::start_srv; -pub use net::DaemonClient; -pub use response::{DaemonResponse, DaemonResponseData, ServerInfo}; -use tokio::sync::RwLock; -use crate::{ - daemon::net::listen, - wireguard::{Interface, Peer, PublicKey, StaticSecret}, -}; +pub use response::{DaemonResponseData, DaemonResponse, ServerInfo}; #[throws] fn parse_key(string: &str) -> [u8; 32] { @@ -48,13 +43,18 @@ fn parse_public_key(string: &str) -> PublicKey { pub async fn daemon_main() -> Result<()> { let (commands_tx, commands_rx) = async_channel::unbounded(); let (response_tx, response_rx) = async_channel::unbounded(); + let mut inst = DaemonInstance::new(commands_rx, response_tx); + + let mut _tun = tun::TunInterface::new()?; + _tun.set_ipv4_addr(Ipv4Addr::from([192, 168, 1, 10]))?; + _tun.set_timeout(Some(std::time::Duration::from_secs(1)))?; + let tun = tun::tokio::TunInterface::new(_tun)?; let private_key = parse_secret_key("GNqIAOCRxjl/cicZyvkvpTklgQuUmGUIEkH7IXF/sEE=")?; let public_key = parse_public_key("uy75leriJay0+oHLhRMpV+A5xAQ0hCJ+q7Ww81AOvT4=")?; let preshared_key = Some(parse_key("s7lx/mg+reVEMnGnqeyYOQkzD86n2+gYnx1M9ygi08k=")?); - let endpoint = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(198, 18, 6, 180)), 51820); // DNS lookup under macos fails, somehow - - let iface = Interface::new(vec![Peer { + let endpoint = "wg.burrow.rs:51820".to_socket_addrs()?.next().unwrap(); + let iface = Interface::new(tun, vec![Peer { endpoint, private_key, public_key, @@ -62,25 +62,6 @@ pub async fn daemon_main() -> Result<()> { allowed_ips: vec![IpNetwork::V4(Ipv4Network::DEFAULT_ROUTE)], }])?; - let mut inst: DaemonInstance = - DaemonInstance::new(commands_rx, response_tx, Arc::new(RwLock::new(iface))); - - tracing::info!("Starting daemon jobs..."); - - let inst_job = tokio::spawn(async move { - let res = inst.run().await; - if let Err(e) = res { - tracing::error!("Error when running instance: {}", e); - } - }); - - let listen_job = tokio::spawn(async move { - let res = listen(commands_tx, response_rx).await; - if let Err(e) = res { - tracing::error!("Error when listening: {}", e); - } - }); - - tokio::try_join!(inst_job, listen_job).map(|_| ()); + iface.run().await; Ok(()) } diff --git a/burrow/src/daemon/net/apple.rs b/burrow/src/daemon/net/apple.rs index b84ec08..e53bdaa 100644 --- a/burrow/src/daemon/net/apple.rs +++ b/burrow/src/daemon/net/apple.rs @@ -1,13 +1,10 @@ use std::thread; - use tokio::runtime::Runtime; -use tracing::{error, info}; - +use tracing::error; use crate::daemon::{daemon_main, DaemonClient}; #[no_mangle] -pub extern "C" fn start_srv() { - info!("Rust: Starting server"); +pub extern "C" fn start_srv(){ let _handle = thread::spawn(move || { let rt = Runtime::new().unwrap(); rt.block_on(async { @@ -19,12 +16,9 @@ pub extern "C" fn start_srv() { let rt = Runtime::new().unwrap(); rt.block_on(async { loop { - match DaemonClient::new().await { - Ok(_) => break, - Err(_e) => { - // error!("Error when connecting to daemon: {}", e) - } + if let Ok(_) = DaemonClient::new().await{ + break } } }); -} +} \ No newline at end of file diff --git a/burrow/src/daemon/net/mod.rs b/burrow/src/daemon/net/mod.rs index d369f40..e5865a3 100644 --- a/burrow/src/daemon/net/mod.rs +++ b/burrow/src/daemon/net/mod.rs @@ -29,3 +29,4 @@ pub struct DaemonRequest { pub id: u32, pub command: DaemonCommand, } + diff --git a/burrow/src/daemon/net/systemd.rs b/burrow/src/daemon/net/systemd.rs index 8a2b29c..1a41f76 100644 --- a/burrow/src/daemon/net/systemd.rs +++ b/burrow/src/daemon/net/systemd.rs @@ -1,23 +1,13 @@ -pub async fn listen( - cmd_tx: async_channel::Sender, - rsp_rx: async_channel::Receiver, -) -> Result<()> { - if !libsystemd::daemon::booted() - || listen_with_systemd(cmd_tx.clone(), rsp_rx.clone()) - .await - .is_err() - { +pub async fn listen(cmd_tx: async_channel::Sender, rsp_rx: async_channel::Receiver) -> Result<()> { + if !libsystemd::daemon::booted() || listen_with_systemd(cmd_tx.clone(), rsp_rx.clone()).await.is_err() { unix::listen(cmd_tx, rsp_rx).await?; } Ok(()) } -async fn listen_with_systemd( - cmd_tx: async_channel::Sender, - rsp_rx: async_channel::Receiver, -) -> Result<()> { +async fn listen_with_systemd(cmd_tx: async_channel::Sender, rsp_rx: async_channel::Receiver) -> Result<()> { let fds = libsystemd::activation::receive_descriptors(false)?; - super::unix::listen_with_optional_fd(cmd_tx, rsp_rx, Some(fds[0].clone().into_raw_fd())).await + super::unix::listen_with_optional_fd(cmd_tx, rsp_rx,Some(fds[0].clone().into_raw_fd())).await } pub type DaemonClient = unix::DaemonClient; diff --git a/burrow/src/daemon/net/unix.rs b/burrow/src/daemon/net/unix.rs index 7ce8992..928473b 100644 --- a/burrow/src/daemon/net/unix.rs +++ b/burrow/src/daemon/net/unix.rs @@ -1,20 +1,22 @@ +use super::*; use std::{ - io, - os::{ + ascii, io, os::{ fd::{FromRawFd, RawFd}, unix::net::UnixListener as StdUnixListener, }, - path::{Path, PathBuf}, -}; + path::Path}; +use std::hash::Hash; +use std::path::PathBuf; +use anyhow::anyhow; +use log::log; +use tracing::info; -use anyhow::{anyhow, Result}; +use anyhow::Result; use tokio::{ io::{AsyncBufReadExt, AsyncWriteExt, BufReader}, net::{UnixListener, UnixStream}, }; -use tracing::{debug, info}; - -use super::*; +use tracing::debug; use crate::daemon::{DaemonCommand, DaemonResponse, DaemonResponseData}; #[cfg(not(target_vendor = "apple"))] @@ -24,33 +26,28 @@ const UNIX_SOCKET_PATH: &str = "/run/burrow.sock"; const UNIX_SOCKET_PATH: &str = "burrow.sock"; #[cfg(target_os = "macos")] -fn fetch_socket_path() -> Option { +fn fetch_socket_path() -> Option{ let tries = vec![ "burrow.sock".to_string(), - format!( - "{}/Library/Containers/com.hackclub.burrow.network/Data/burrow.sock", - std::env::var("HOME").unwrap_or_default() - ) - .to_string(), + format!("{}/Library/Containers/com.hackclub.burrow.network/Data/burrow.sock", + std::env::var("HOME").unwrap_or_default()) + .to_string(), ]; - for path in tries { + for path in tries{ let path = PathBuf::from(path); - if path.exists() { - return Some(path) + if path.exists(){ + return Some(path); } } None } #[cfg(not(target_os = "macos"))] -fn fetch_socket_path() -> Option { +fn fetch_socket_path() -> Option{ Some(Path::new(UNIX_SOCKET_PATH).to_path_buf()) } -pub async fn listen( - cmd_tx: async_channel::Sender, - rsp_rx: async_channel::Receiver, -) -> Result<()> { +pub async fn listen(cmd_tx: async_channel::Sender, rsp_rx: async_channel::Receiver) -> Result<()> { listen_with_optional_fd(cmd_tx, rsp_rx, None).await } @@ -72,12 +69,14 @@ pub(crate) async fn listen_with_optional_fd( listener } else { // Won't help all that much, if we use the async version of fs. - if let Some(par) = path.parent() { - std::fs::create_dir_all(par)?; + if let Some(par) = path.parent(){ + std::fs::create_dir_all( + par + )?; } - match std::fs::remove_file(path) { - Err(e) if e.kind() == io::ErrorKind::NotFound => Ok(()), - stuff => stuff, + match std::fs::remove_file(path){ + Err(e) if e.kind()==io::ErrorKind::NotFound => {Ok(())} + stuff => stuff }?; info!("Relative path: {}", path.to_string_lossy()); UnixListener::bind(path)? @@ -99,18 +98,18 @@ pub(crate) async fn listen_with_optional_fd( while let Ok(Some(line)) = lines.next_line().await { info!("Got line: {}", line); debug!("Line raw data: {:?}", line.as_bytes()); - let mut res: DaemonResponse = DaemonResponseData::None.into(); + let mut res : DaemonResponse = DaemonResponseData::None.into(); let req = match serde_json::from_str::(&line) { Ok(req) => Some(req), Err(e) => { res.result = Err(e.to_string()); - tracing::error!("Failed to parse request: {}", e); None } }; let mut res = serde_json::to_string(&res).unwrap(); res.push('\n'); + if let Some(req) = req { cmd_tx.send(req.command).await.unwrap(); let res = rsp_rxc.recv().await.unwrap().with_id(req.id); @@ -118,8 +117,6 @@ pub(crate) async fn listen_with_optional_fd( retres.push('\n'); info!("Sending response: {}", retres); write_stream.write_all(retres.as_bytes()).await.unwrap(); - } else { - write_stream.write_all(res.as_bytes()).await.unwrap(); } } }); @@ -132,7 +129,8 @@ pub struct DaemonClient { impl DaemonClient { pub async fn new() -> Result { - let path = fetch_socket_path().ok_or(anyhow!("Failed to find socket path"))?; + let path = fetch_socket_path() + .ok_or(anyhow!("Failed to find socket path"))?; // debug!("found path: {:?}", path); let connection = UnixStream::connect(path).await?; debug!("connected to socket"); diff --git a/burrow/src/daemon/net/windows.rs b/burrow/src/daemon/net/windows.rs index c734689..3f9d513 100644 --- a/burrow/src/daemon/net/windows.rs +++ b/burrow/src/daemon/net/windows.rs @@ -1,9 +1,6 @@ use super::*; -pub async fn listen( - _cmd_tx: async_channel::Sender, - _rsp_rx: async_channel::Receiver, -) -> Result<()> { +pub async fn listen(_cmd_tx: async_channel::Sender, _rsp_rx: async_channel::Receiver) -> Result<()> { unimplemented!("This platform does not currently support daemon mode.") } diff --git a/burrow/src/daemon/response.rs b/burrow/src/daemon/response.rs index 4bebe14..da47150 100644 --- a/burrow/src/daemon/response.rs +++ b/burrow/src/daemon/response.rs @@ -1,3 +1,4 @@ +use anyhow::anyhow; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use tun::TunInterface; @@ -6,27 +7,30 @@ use tun::TunInterface; pub struct DaemonResponse { // Error types can't be serialized, so this is the second best option. pub result: Result, - pub id: u32, + pub id: u32 } -impl DaemonResponse { - pub fn new(result: Result) -> Self { - Self { +impl DaemonResponse{ + pub fn new(result: Result) -> Self{ + Self{ result: result.map_err(|e| e.to_string()), - id: 0, + id: 0 } } } -impl Into for DaemonResponseData { - fn into(self) -> DaemonResponse { +impl Into for DaemonResponseData{ + fn into(self) -> DaemonResponse{ DaemonResponse::new(Ok::(self)) } } -impl DaemonResponse { - pub fn with_id(self, id: u32) -> Self { - Self { id, ..self } +impl DaemonResponse{ + pub fn with_id(self, id: u32) -> Self{ + Self { + id, + ..self + } } } @@ -34,22 +38,24 @@ impl DaemonResponse { pub struct ServerInfo { pub name: Option, pub ip: Option, - pub mtu: Option, + pub mtu: Option } -impl TryFrom<&TunInterface> for ServerInfo { +impl TryFrom<&TunInterface> for ServerInfo{ type Error = anyhow::Error; - #[cfg(any(target_os = "linux", target_vendor = "apple"))] + #[cfg(any(target_os="linux",target_vendor="apple"))] fn try_from(server: &TunInterface) -> anyhow::Result { - Ok(ServerInfo { - name: server.name().ok(), - ip: server.ipv4_addr().ok().map(|ip| ip.to_string()), - mtu: server.mtu().ok(), - }) + Ok( + ServerInfo{ + name: server.name().ok(), + ip: server.ipv4_addr().ok().map(|ip| ip.to_string()), + mtu: server.mtu().ok() + } + ) } - #[cfg(not(any(target_os = "linux", target_vendor = "apple")))] + #[cfg(not(any(target_os="linux",target_vendor="apple")))] fn try_from(server: &TunInterface) -> anyhow::Result { Err(anyhow!("Not implemented in this platform")) } @@ -59,55 +65,45 @@ impl TryFrom<&TunInterface> for ServerInfo { pub struct ServerConfig { pub address: Option, pub name: Option, - pub mtu: Option, + pub mtu: Option } impl Default for ServerConfig { fn default() -> Self { - Self { - address: Some("10.13.13.2".to_string()), // Dummy remote address + Self{ + address: Some("10.0.0.1".to_string()), // Dummy remote address name: None, - mtu: None, + mtu: None } } } #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] -pub enum DaemonResponseData { +pub enum DaemonResponseData{ ServerInfo(ServerInfo), ServerConfig(ServerConfig), - None, + None } #[test] -fn test_response_serialization() -> anyhow::Result<()> { - insta::assert_snapshot!(serde_json::to_string(&DaemonResponse::new(Ok::< - DaemonResponseData, - String, - >( - DaemonResponseData::None - )))?); - insta::assert_snapshot!(serde_json::to_string(&DaemonResponse::new(Ok::< - DaemonResponseData, - String, - >( - DaemonResponseData::ServerInfo(ServerInfo { +fn test_response_serialization() -> anyhow::Result<()>{ + insta::assert_snapshot!( + serde_json::to_string(&DaemonResponse::new(Ok::(DaemonResponseData::None)))? + ); + insta::assert_snapshot!( + serde_json::to_string(&DaemonResponse::new(Ok::(DaemonResponseData::ServerInfo(ServerInfo{ name: Some("burrow".to_string()), ip: None, mtu: Some(1500) - }) - )))?); - insta::assert_snapshot!(serde_json::to_string(&DaemonResponse::new(Err::< - DaemonResponseData, - String, - >( - "error".to_string() - )))?); - insta::assert_snapshot!(serde_json::to_string(&DaemonResponse::new(Ok::< - DaemonResponseData, - String, - >( - DaemonResponseData::ServerConfig(ServerConfig::default()) - )))?); + }))))? + ); + insta::assert_snapshot!( + serde_json::to_string(&DaemonResponse::new(Err::("error".to_string())))? + ); + insta::assert_snapshot!( + serde_json::to_string(&DaemonResponse::new(Ok::(DaemonResponseData::ServerConfig( + ServerConfig::default() + ))))? + ); Ok(()) -} +} \ No newline at end of file diff --git a/burrow/src/daemon/snapshots/burrow__daemon__command__daemoncommand_serialization-2.snap b/burrow/src/daemon/snapshots/burrow__daemon__command__daemoncommand_serialization-2.snap index 289851f..80b9e24 100644 --- a/burrow/src/daemon/snapshots/burrow__daemon__command__daemoncommand_serialization-2.snap +++ b/burrow/src/daemon/snapshots/burrow__daemon__command__daemoncommand_serialization-2.snap @@ -1,5 +1,5 @@ --- source: burrow/src/daemon/command.rs -expression: "serde_json::to_string(&DaemonCommand::Start(DaemonStartOptions {\n tun: TunOptions { seek_utun: true, ..TunOptions::default() },\n })).unwrap()" +expression: "serde_json::to_string(&DaemonCommand::ServerInfo).unwrap()" --- -{"Start":{"tun":{"name":null,"no_pi":false,"tun_excl":false,"seek_utun":true,"address":null}}} +"ServerInfo" diff --git a/burrow/src/daemon/snapshots/burrow__daemon__command__daemoncommand_serialization-3.snap b/burrow/src/daemon/snapshots/burrow__daemon__command__daemoncommand_serialization-3.snap index 80b9e24..8dc1b8b 100644 --- a/burrow/src/daemon/snapshots/burrow__daemon__command__daemoncommand_serialization-3.snap +++ b/burrow/src/daemon/snapshots/burrow__daemon__command__daemoncommand_serialization-3.snap @@ -1,5 +1,5 @@ --- source: burrow/src/daemon/command.rs -expression: "serde_json::to_string(&DaemonCommand::ServerInfo).unwrap()" +expression: "serde_json::to_string(&DaemonCommand::Stop).unwrap()" --- -"ServerInfo" +"Stop" diff --git a/burrow/src/daemon/snapshots/burrow__daemon__command__daemoncommand_serialization-4.snap b/burrow/src/daemon/snapshots/burrow__daemon__command__daemoncommand_serialization-4.snap index 8dc1b8b..9334ece 100644 --- a/burrow/src/daemon/snapshots/burrow__daemon__command__daemoncommand_serialization-4.snap +++ b/burrow/src/daemon/snapshots/burrow__daemon__command__daemoncommand_serialization-4.snap @@ -1,5 +1,5 @@ --- source: burrow/src/daemon/command.rs -expression: "serde_json::to_string(&DaemonCommand::Stop).unwrap()" +expression: "serde_json::to_string(&DaemonCommand::ServerConfig).unwrap()" --- -"Stop" +"ServerConfig" diff --git a/burrow/src/daemon/snapshots/burrow__daemon__command__daemoncommand_serialization-5.snap b/burrow/src/daemon/snapshots/burrow__daemon__command__daemoncommand_serialization-5.snap deleted file mode 100644 index 9334ece..0000000 --- a/burrow/src/daemon/snapshots/burrow__daemon__command__daemoncommand_serialization-5.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: burrow/src/daemon/command.rs -expression: "serde_json::to_string(&DaemonCommand::ServerConfig).unwrap()" ---- -"ServerConfig" diff --git a/burrow/src/daemon/snapshots/burrow__daemon__command__daemoncommand_serialization.snap b/burrow/src/daemon/snapshots/burrow__daemon__command__daemoncommand_serialization.snap index ff32838..2f8af66 100644 --- a/burrow/src/daemon/snapshots/burrow__daemon__command__daemoncommand_serialization.snap +++ b/burrow/src/daemon/snapshots/burrow__daemon__command__daemoncommand_serialization.snap @@ -2,4 +2,4 @@ source: burrow/src/daemon/command.rs expression: "serde_json::to_string(&DaemonCommand::Start(DaemonStartOptions::default())).unwrap()" --- -{"Start":{"tun":{"name":null,"no_pi":false,"tun_excl":false,"seek_utun":false,"address":null}}} +{"Start":{"tun":{"name":null,"no_pi":null,"tun_excl":null}}} diff --git a/burrow/src/daemon/snapshots/burrow__daemon__response__response_serialization-4.snap b/burrow/src/daemon/snapshots/burrow__daemon__response__response_serialization-4.snap index 9752ebc..95f9e7b 100644 --- a/burrow/src/daemon/snapshots/burrow__daemon__response__response_serialization-4.snap +++ b/burrow/src/daemon/snapshots/burrow__daemon__response__response_serialization-4.snap @@ -2,4 +2,4 @@ source: burrow/src/daemon/response.rs expression: "serde_json::to_string(&DaemonResponse::new(Ok::(DaemonResponseData::ServerConfig(ServerConfig::default()))))?" --- -{"result":{"Ok":{"ServerConfig":{"address":"10.13.13.2","name":null,"mtu":null}}},"id":0} +{"result":{"Ok":{"ServerConfig":{"address":"10.0.0.1","name":null,"mtu":null}}},"id":0} diff --git a/burrow/src/ensureroot.rs b/burrow/src/ensureroot.rs new file mode 100644 index 0000000..b7c0757 --- /dev/null +++ b/burrow/src/ensureroot.rs @@ -0,0 +1,40 @@ +use tracing::instrument; + +// Check capabilities on Linux +#[cfg(target_os = "linux")] +#[instrument] +pub fn ensure_root() { + use caps::{has_cap, CapSet, Capability}; + + let cap_net_admin = Capability::CAP_NET_ADMIN; + if let Ok(has_cap) = has_cap(None, CapSet::Effective, cap_net_admin) { + if !has_cap { + eprintln!( + "This action needs the CAP_NET_ADMIN permission. Did you mean to run it as root?" + ); + std::process::exit(77); + } + } else { + eprintln!("Failed to check capabilities. Please file a bug report!"); + std::process::exit(71); + } +} + +// Check for root user on macOS +#[cfg(target_vendor = "apple")] +#[instrument] +pub fn ensure_root() { + use nix::unistd::Uid; + + let current_uid = Uid::current(); + if !current_uid.is_root() { + eprintln!("This action must be run as root!"); + std::process::exit(77); + } +} + +#[cfg(target_family = "windows")] +#[instrument] +pub fn ensure_root() { + todo!() +} diff --git a/burrow/src/lib.rs b/burrow/src/lib.rs index 9df60f0..030022d 100644 --- a/burrow/src/lib.rs +++ b/burrow/src/lib.rs @@ -1,16 +1,44 @@ +pub mod ensureroot; pub mod wireguard; -mod daemon; -pub use daemon::{ - DaemonCommand, - DaemonResponse, - DaemonResponseData, - DaemonStartOptions, - ServerInfo, +use anyhow::Result; + +#[cfg(any(target_os = "linux", target_vendor = "apple"))] +use std::{ + mem, + os::fd::{AsRawFd, FromRawFd}, }; +use tun::TunInterface; + +// TODO Separate start and retrieve functions + +mod daemon; +pub use daemon::{DaemonCommand, DaemonResponseData, DaemonStartOptions, DaemonResponse, ServerInfo}; + #[cfg(target_vendor = "apple")] mod apple; #[cfg(target_vendor = "apple")] pub use apple::*; + +#[cfg(any(target_os = "linux", target_vendor = "apple"))] +#[no_mangle] +pub extern "C" fn retrieve() -> i32 { + let iface2 = (1..100) + .filter_map(|i| { + let iface = unsafe { TunInterface::from_raw_fd(i) }; + match iface.name() { + Ok(_name) => Some(iface), + Err(_) => { + mem::forget(iface); + None + } + } + }) + .next(); + match iface2 { + Some(iface) => iface.as_raw_fd(), + None => -1, + } +} diff --git a/burrow/src/main.rs b/burrow/src/main.rs index 5003277..2e89a48 100644 --- a/burrow/src/main.rs +++ b/burrow/src/main.rs @@ -1,19 +1,23 @@ +use std::mem; +#[cfg(any(target_os = "linux", target_vendor = "apple"))] +use std::os::fd::FromRawFd; + use anyhow::{Context, Result}; #[cfg(any(target_os = "linux", target_vendor = "apple"))] +use burrow::retrieve; use clap::{Args, Parser, Subcommand}; -use tracing::instrument; +use tracing::{instrument, Level}; use tracing_log::LogTracer; use tracing_oslog::OsLogger; -use tracing_subscriber::{prelude::*, EnvFilter, FmtSubscriber}; +use tracing_subscriber::{prelude::*, FmtSubscriber, EnvFilter}; #[cfg(any(target_os = "linux", target_vendor = "apple"))] use tun::TunInterface; + mod daemon; mod wireguard; use daemon::{DaemonClient, DaemonCommand, DaemonStartOptions}; -use tun::TunOptions; - use crate::daemon::DaemonResponseData; #[derive(Parser)] @@ -61,9 +65,7 @@ struct DaemonArgs {} async fn try_start() -> Result<()> { let mut client = DaemonClient::new().await?; client - .send_command(DaemonCommand::Start(DaemonStartOptions { - tun: TunOptions::new().address("10.13.13.2"), - })) + .send_command(DaemonCommand::Start(DaemonStartOptions::default())) .await .map(|_| ()) } @@ -85,8 +87,9 @@ async fn try_retrieve() -> Result<()> { } } - let iface2 = TunInterface::retrieve().ok_or(anyhow::anyhow!("No interface found"))?; - tracing::info!("{:?}", iface2); + burrow::ensureroot::ensure_root(); + let iface2 = retrieve(); + tracing::info!("{}", iface2); Ok(()) } @@ -101,10 +104,9 @@ async fn initialize_tracing() -> Result<()> { FmtSubscriber::builder() .with_line_number(true) .with_env_filter(EnvFilter::from_default_env()) - .finish(), + .finish() ); - tracing::subscriber::set_global_default(logger) - .context("Failed to set the global tracing subscriber")?; + tracing::subscriber::set_global_default(logger).context("Failed to set the global tracing subscriber")?; } } @@ -119,7 +121,7 @@ async fn try_stop() -> Result<()> { } #[cfg(any(target_os = "linux", target_vendor = "apple"))] -async fn try_serverinfo() -> Result<()> { +async fn try_serverinfo() -> Result<()>{ let mut client = DaemonClient::new().await?; let res = client.send_command(DaemonCommand::ServerInfo).await?; match res.result { @@ -129,9 +131,7 @@ async fn try_serverinfo() -> Result<()> { Ok(DaemonResponseData::None) => { println!("Server not started.") } - Ok(res) => { - println!("Unexpected Response: {:?}", res) - } + Ok(res) => {println!("Unexpected Response: {:?}", res)} Err(e) => { println!("Error when retrieving from server: {}", e) } @@ -140,7 +140,7 @@ async fn try_serverinfo() -> Result<()> { } #[cfg(any(target_os = "linux", target_vendor = "apple"))] -async fn try_serverconfig() -> Result<()> { +async fn try_serverconfig() -> Result<()>{ let mut client = DaemonClient::new().await?; let res = client.send_command(DaemonCommand::ServerConfig).await?; match res.result { @@ -150,9 +150,7 @@ async fn try_serverconfig() -> Result<()> { Ok(DaemonResponseData::None) => { println!("Server not started.") } - Ok(res) => { - println!("Unexpected Response: {:?}", res) - } + Ok(res) => {println!("Unexpected Response: {:?}", res)} Err(e) => { println!("Error when retrieving from server: {}", e) } @@ -203,8 +201,12 @@ async fn main() -> Result<()> { try_stop().await?; } Commands::Daemon(_) => daemon::daemon_main().await?, - Commands::ServerInfo => try_serverinfo().await?, - Commands::ServerConfig => try_serverconfig().await?, + Commands::ServerInfo => { + try_serverinfo().await? + } + Commands::ServerConfig => { + try_serverconfig().await? + } } Ok(()) diff --git a/burrow/src/wireguard/iface.rs b/burrow/src/wireguard/iface.rs index 3d1823b..7f6473c 100755 --- a/burrow/src/wireguard/iface.rs +++ b/burrow/src/wireguard/iface.rs @@ -1,15 +1,23 @@ -use std::{net::IpAddr, sync::Arc, time::Duration}; +use std::{net::IpAddr, rc::Rc}; +use std::sync::Arc; +use std::time::Duration; use anyhow::Error; use async_trait::async_trait; use fehler::throws; -use futures::{future::join_all, FutureExt}; use ip_network_table::IpNetworkTable; -use tokio::{sync::RwLock, task::JoinHandle, time::timeout}; -use tracing::{debug, error}; +use log::log; +use tokio::{ + join, + sync::{Mutex, RwLock}, + task::{self, JoinHandle}, +}; use tun::tokio::TunInterface; +use futures::future::join_all; +use futures::FutureExt; +use tokio::time::timeout; -use super::{noise::Tunnel, Peer, PeerPcb}; +use super::{noise::Tunnel, pcb, Peer, PeerPcb}; #[async_trait] pub trait PacketInterface { @@ -29,7 +37,7 @@ impl PacketInterface for tun::tokio::TunInterface { } struct IndexedPcbs { - pcbs: Vec>, + pcbs: Vec>>, allowed_ips: IpNetworkTable, } @@ -46,7 +54,7 @@ impl IndexedPcbs { for allowed_ip in pcb.allowed_ips.iter() { self.allowed_ips.insert(allowed_ip.clone(), idx); } - self.pcbs.insert(idx, Arc::new(pcb)); + self.pcbs.insert(idx, Arc::new(RwLock::new(pcb))); } pub fn find(&self, addr: IpAddr) -> Option { @@ -55,7 +63,7 @@ impl IndexedPcbs { } pub async fn connect(&self, idx: usize, handle: JoinHandle<()>) { - self.pcbs[idx].handle.write().await.replace(handle); + self.pcbs[idx].write().await.handle = Some(handle); } } @@ -69,124 +77,115 @@ impl FromIterator for IndexedPcbs { } pub struct Interface { - tun: Option>>, + tun: Arc>, pcbs: Arc, } impl Interface { #[throws] - pub fn new>(peers: I) -> Self { - let pcbs: IndexedPcbs = peers + pub fn new>(tun: TunInterface, peers: I) -> Self { + let mut pcbs: IndexedPcbs = peers .into_iter() .map(|peer| PeerPcb::new(peer)) .collect::>()?; + let tun = Arc::new(RwLock::new(tun)); let pcbs = Arc::new(pcbs); - Self { pcbs, tun: None } + Self { tun, pcbs } } - pub fn set_tun(&mut self, tun: Arc>) { - self.tun = Some(tun); - } - - pub async fn run(&self) -> anyhow::Result<()> { - debug!("RUN: starting interface"); + pub async fn run(self) { let pcbs = self.pcbs.clone(); - let tun = self - .tun - .clone() - .ok_or(anyhow::anyhow!("tun interface does not exist"))?; + let tun = self.tun.clone(); log::info!("starting interface"); let outgoing = async move { loop { - // tracing::debug!("starting loop..."); + log::debug!("starting loop..."); let mut buf = [0u8; 3000]; let src = { - let src = match timeout( - Duration::from_millis(10), - tun.read().await.recv(&mut buf[..]), - ) - .await - { + log::debug!("awaiting read..."); + let src = match timeout(Duration::from_secs(2), tun.write().await.recv(&mut buf[..])).await { Ok(Ok(len)) => &buf[..len], - Ok(Err(e)) => { - error!("failed to read from interface: {}", e); + Ok(Err(e)) => {continue} + Err(_would_block) => { continue } - Err(_would_block) => continue, }; - debug!("read {} bytes from interface", src.len()); - debug!("bytes: {:?}", src); + log::debug!("read {} bytes from interface", src.len()); + log::debug!("bytes: {:?}", src); src }; + let dst_addr = match Tunnel::dst_address(src) { Some(addr) => addr, None => { - tracing::debug!("no destination found"); + log::debug!("no destination found"); continue - } + }, }; - tracing::debug!("dst_addr: {}", dst_addr); + log::debug!("dst_addr: {}", dst_addr); let Some(idx) = pcbs.find(dst_addr) else { continue }; - tracing::debug!("found peer:{}", idx); + log::debug!("found peer:{}", idx); - match pcbs.pcbs[idx].send(src).await { + match pcbs.pcbs[idx].read().await.send(src).await { Ok(..) => { - let addr = pcbs.pcbs[idx].endpoint; - tracing::debug!("sent packet to peer {}", addr); + log::debug!("sent packet to peer {}", dst_addr); } Err(e) => { log::error!("failed to send packet {}", e); continue - } + }, }; + + // let mut buf = [0u8; 3000]; + // match pcbs.pcbs[idx].read().await.recv(&mut buf).await { + // Ok(len) => log::debug!("received {} bytes from peer {}", len, dst_addr), + // Err(e) => { + // log::error!("failed to receive packet {}", e); + // continue + // }, + // } } }; let mut tsks = vec![]; - let tun = self - .tun - .clone() - .ok_or(anyhow::anyhow!("tun interface does not exist"))?; + let tun = self.tun.clone(); let outgoing = tokio::task::spawn(outgoing); tsks.push(outgoing); - debug!("preparing to spawn read tasks"); - { - let pcbs = &self.pcbs; - for i in 0..pcbs.pcbs.len() { - debug!("spawning read task for peer {}", i); - let pcb = pcbs.pcbs[i].clone(); + let pcbs = self.pcbs; + for i in 0..pcbs.pcbs.len(){ + let mut pcb = pcbs.pcbs[i].clone(); let tun = tun.clone(); let tsk = async move { - if let Err(e) = pcb.open_if_closed().await { - log::error!("failed to open pcb: {}", e); - return + { + let r1 = pcb.write().await.open_if_closed().await; + if let Err(e) = r1 { + log::error!("failed to open pcb: {}", e); + return + } } - let r2 = pcb.run(tun).await; + let r2 = pcb.read().await.run().await; if let Err(e) = r2 { log::error!("failed to run pcb: {}", e); return } else { - tracing::debug!("pcb ran successfully"); + log::debug!("pcb ran successfully"); } }; - debug!("task made.."); tsks.push(tokio::spawn(tsk)); } - debug!("spawned read tasks"); + log::debug!("spawned read tasks"); } - debug!("preparing to join.."); + log::debug!("preparing to join.."); join_all(tsks).await; - debug!("joined!"); - Ok(()) } } diff --git a/burrow/src/wireguard/noise/handshake.rs b/burrow/src/wireguard/noise/handshake.rs index 2ec0c6a..92c456c 100755 --- a/burrow/src/wireguard/noise/handshake.rs +++ b/burrow/src/wireguard/noise/handshake.rs @@ -136,6 +136,10 @@ fn aead_chacha20_open( ) -> Result<(), WireGuardError> { let mut nonce: [u8; 12] = [0; 12]; nonce[4..].copy_from_slice(&counter.to_le_bytes()); + + log::debug!("TAG A"); + log::debug!("{:?};{:?};{:?};{:?};{}", key, data, aad, nonce, counter); + aead_chacha20_open_inner(buffer, key, nonce, data, aad) .map_err(|_| WireGuardError::InvalidAeadTag)?; Ok(()) @@ -679,6 +683,7 @@ impl Handshake { aad: &mac1[0..16], msg: packet.encrypted_cookie, }; + log::debug!("TAG B"); let plaintext = XChaCha20Poly1305::new_from_slice(&key) .unwrap() .decrypt(packet.nonce.into(), payload) diff --git a/burrow/src/wireguard/noise/mod.rs b/burrow/src/wireguard/noise/mod.rs index 3a60c22..7e2184d 100755 --- a/burrow/src/wireguard/noise/mod.rs +++ b/burrow/src/wireguard/noise/mod.rs @@ -146,7 +146,7 @@ impl Tunnel { // Checks the type, as well as the reserved zero fields let packet_type = u32::from_le_bytes(src[0..4].try_into().unwrap()); - tracing::debug!("packet_type: {}", packet_type); + log::debug!("packet_type: {}", packet_type); Ok(match (packet_type, src.len()) { (HANDSHAKE_INIT, HANDSHAKE_INIT_SZ) => Packet::HandshakeInit(HandshakeInit { diff --git a/burrow/src/wireguard/noise/rate_limiter.rs b/burrow/src/wireguard/noise/rate_limiter.rs index ff19efd..df29f93 100755 --- a/burrow/src/wireguard/noise/rate_limiter.rs +++ b/burrow/src/wireguard/noise/rate_limiter.rs @@ -6,6 +6,7 @@ use std::{ use aead::{generic_array::GenericArray, AeadInPlace, KeyInit}; use chacha20poly1305::{Key, XChaCha20Poly1305}; +use log::log; use parking_lot::Mutex; use rand_core::{OsRng, RngCore}; use ring::constant_time::verify_slices_are_equal; @@ -173,14 +174,14 @@ impl RateLimiter { dst: &'b mut [u8], ) -> Result, TunnResult<'b>> { let packet = Tunnel::parse_incoming_packet(src)?; - tracing::debug!("packet: {:?}", packet); + log::debug!("packet: {:?}", packet); // Verify and rate limit handshake messages only if let Packet::HandshakeInit(HandshakeInit { sender_idx, .. }) | Packet::HandshakeResponse(HandshakeResponse { sender_idx, .. }) = packet { - tracing::debug!("sender_idx: {}", sender_idx); - tracing::debug!("response: {:?}", packet); + log::debug!("sender_idx: {}", sender_idx); + log::debug!("response: {:?}", packet); let (msg, macs) = src.split_at(src.len() - 32); let (mac1, mac2) = macs.split_at(16); diff --git a/burrow/src/wireguard/noise/session.rs b/burrow/src/wireguard/noise/session.rs index 8988728..eb7dbef 100755 --- a/burrow/src/wireguard/noise/session.rs +++ b/burrow/src/wireguard/noise/session.rs @@ -253,7 +253,7 @@ impl Session { // check the counter without running expensive decryption self.receiving_counter_quick_check(packet.counter)?; - tracing::debug!("TAG C"); + log::debug!("TAG C"); let ret = { let mut nonce = [0u8; 12]; nonce[4..12].copy_from_slice(&packet.counter.to_le_bytes()); diff --git a/burrow/src/wireguard/pcb.rs b/burrow/src/wireguard/pcb.rs index 21b1d6e..4ec63c5 100755 --- a/burrow/src/wireguard/pcb.rs +++ b/burrow/src/wireguard/pcb.rs @@ -1,17 +1,21 @@ -use std::{ - cell::{Cell, RefCell}, - net::SocketAddr, - sync::Arc, -}; +use std::io; +use std::net::SocketAddr; +use std::rc::Rc; +use std::sync::Arc; +use std::time::Duration; use anyhow::{anyhow, Error}; use fehler::throws; use ip_network::IpNetwork; +use log::log; use rand::random; -use tokio::{net::UdpSocket, sync::RwLock, task::JoinHandle}; -use tun::tokio::TunInterface; +use tokio::{net::UdpSocket, task::JoinHandle}; +use tokio::sync::{Mutex, RwLock}; +use tokio::time::timeout; +use uuid::uuid; use super::{ + iface::PacketInterface, noise::{TunnResult, Tunnel}, Peer, }; @@ -20,101 +24,107 @@ use super::{ pub struct PeerPcb { pub endpoint: SocketAddr, pub allowed_ips: Vec, - pub handle: RwLock>>, - socket: RwLock>, + pub handle: Option>, + socket: Option, tunnel: RwLock, } impl PeerPcb { #[throws] pub fn new(peer: Peer) -> Self { - let tunnel = RwLock::new( - Tunnel::new( - peer.private_key, - peer.public_key, - peer.preshared_key, - None, - 1, - None, - ) - .map_err(|s| anyhow::anyhow!("{}", s))?, - ); + let tunnel = RwLock::new(Tunnel::new(peer.private_key, peer.public_key, peer.preshared_key, None, 1, None) + .map_err(|s| anyhow::anyhow!("{}", s))?); + Self { endpoint: peer.endpoint, allowed_ips: peer.allowed_ips, - handle: RwLock::new(None), - socket: RwLock::new(None), + handle: None, + socket: None, tunnel, } } - pub async fn open_if_closed(&self) -> Result<(), Error> { - if self.socket.read().await.is_none() { + pub async fn open_if_closed(&mut self) -> Result<(), Error> { + if self.socket.is_none() { let socket = UdpSocket::bind("0.0.0.0:0").await?; socket.connect(self.endpoint).await?; - self.socket.write().await.replace(socket); + self.socket = Some(socket); } Ok(()) } - pub async fn run(&self, tun_interface: Arc>) -> Result<(), Error> { - tracing::debug!("starting read loop for pcb... for {:?}", &self); - let rid: i32 = random(); - let mut buf: [u8; 3000] = [0u8; 3000]; - tracing::debug!("start read loop {}", rid); + pub async fn run(&self) -> Result<(), Error> { + let mut buf = [0u8; 3000]; + log::debug!("starting read loop for pcb..."); loop { - tracing::debug!("{}: waiting for packet", rid); - let guard = self.socket.read().await; - let Some(socket) = guard.as_ref() else { + tracing::debug!("waiting for packet"); + let len = self.recv(&mut buf).await?; + tracing::debug!("received {} bytes", len); + } + } + + pub async fn recv(&self, buf: &mut [u8]) -> Result { + log::debug!("starting read loop for pcb... for {:?}", &self); + let rid: i32 = random(); + log::debug!("start read loop {}", rid); + loop{ + log::debug!("{}: waiting for packet", rid); + let Some(socket) = &self.socket else { continue }; - let mut res_buf = [0; 1500]; - // tracing::debug!("{} : waiting for readability on {:?}", rid, socket); - let len = match socket.recv(&mut res_buf).await { - Ok(l) => l, + let mut res_buf = [0;1500]; + log::debug!("{} : waiting for readability on {:?}", rid, socket); + match timeout(Duration::from_secs(2), socket.readable()).await { Err(e) => { - log::error!("{}: error reading from socket: {:?}", rid, e); + log::debug!("{}: timeout waiting for readability on {:?}", rid, e); continue } + Ok(Err(e)) => { + log::debug!("{}: error waiting for readability on {:?}", rid, e); + continue + } + Ok(Ok(_)) => {} + }; + log::debug!("{}: readable!", rid); + let Ok(len) = socket.try_recv(&mut res_buf) else { + continue }; let mut res_dat = &res_buf[..len]; tracing::debug!("{}: Decapsulating {} bytes", rid, len); tracing::debug!("{:?}", &res_dat); loop { - match self - .tunnel - .write() - .await - .decapsulate(None, res_dat, &mut buf[..]) - { - TunnResult::Done => break, + match self.tunnel.write().await.decapsulate(None, res_dat, &mut buf[..]) { + TunnResult::Done => { + break; + } TunnResult::Err(e) => { tracing::error!(message = "Decapsulate error", error = ?e); - break + break; } TunnResult::WriteToNetwork(packet) => { tracing::debug!("WriteToNetwork: {:?}", packet); - self.open_if_closed().await?; socket.send(packet).await?; tracing::debug!("WriteToNetwork done"); res_dat = &[]; - continue + continue; } TunnResult::WriteToTunnelV4(packet, addr) => { tracing::debug!("WriteToTunnelV4: {:?}, {:?}", packet, addr); - tun_interface.read().await.send(packet).await?; - break - } - TunnResult::WriteToTunnelV6(packet, addr) => { - tracing::debug!("WriteToTunnelV6: {:?}, {:?}", packet, addr); - tun_interface.read().await.send(packet).await?; - break + continue; } + e => panic!("Unexpected result from decapsulate: {:?}", e), } } + return Ok(len) } } + pub async fn socket(&mut self) -> Result<&UdpSocket, Error> { + self.open_if_closed().await?; + Ok(self.socket.as_ref().expect("socket was just opened")) + } + + pub async fn send(&self, src: &[u8]) -> Result<(), Error> { let mut dst_buf = [0u8; 3000]; match self.tunnel.write().await.encapsulate(src, &mut dst_buf[..]) { @@ -123,12 +133,7 @@ impl PeerPcb { tracing::error!(message = "Encapsulate error", error = ?e) } TunnResult::WriteToNetwork(packet) => { - self.open_if_closed().await?; - let handle = self.socket.read().await; - let Some(socket) = handle.as_ref() else { - tracing::error!("No socket for peer"); - return Ok(()) - }; + let socket = &self.socket.as_ref().ok_or(anyhow!("socket not open"))?; tracing::debug!("Our Encapsulated packet: {:?}", packet); socket.send(packet).await?; } diff --git a/burrow/src/wireguard/peer.rs b/burrow/src/wireguard/peer.rs index 131b0d4..cc8a296 100755 --- a/burrow/src/wireguard/peer.rs +++ b/burrow/src/wireguard/peer.rs @@ -1,5 +1,7 @@ use std::{fmt, net::SocketAddr}; +use anyhow::Error; +use fehler::throws; use ip_network::IpNetwork; use x25519_dalek::{PublicKey, StaticSecret}; @@ -8,7 +10,7 @@ pub struct Peer { pub private_key: StaticSecret, pub public_key: PublicKey, pub allowed_ips: Vec, - pub preshared_key: Option<[u8; 32]>, + pub preshared_key: Option<[u8; 32]> } impl fmt::Debug for Peer { diff --git a/tun/build.rs b/tun/build.rs index 03ee131..8da8a40 100644 --- a/tun/build.rs +++ b/tun/build.rs @@ -26,7 +26,7 @@ async fn generate(out_dir: &std::path::Path) -> anyhow::Result<()> { println!("cargo:rerun-if-changed={}", binary_path.to_str().unwrap()); if let (Ok(..), Ok(..)) = (File::open(&bindings_path), File::open(&binary_path)) { - return Ok(()); + return Ok(()) }; let archive = download(out_dir) diff --git a/tun/src/options.rs b/tun/src/options.rs index 7c414dc..3fe5a13 100644 --- a/tun/src/options.rs +++ b/tun/src/options.rs @@ -5,48 +5,28 @@ use fehler::throws; use super::TunInterface; #[derive(Debug, Clone, Default)] -#[cfg_attr( - feature = "serde", - derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema) -)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema))] pub struct TunOptions { /// (Windows + Linux) Name the tun interface. - pub name: Option, + pub(crate) name: Option, /// (Linux) Don't include packet information. - pub no_pi: bool, + pub(crate) no_pi: Option<()>, /// (Linux) Avoid opening an existing persistant device. - pub tun_excl: bool, - /// (Linux) The IP address of the tun interface. - pub address: Option, + pub(crate) tun_excl: Option<()>, } impl TunOptions { - pub fn new() -> Self { - Self::default() - } + pub fn new() -> Self { Self::default() } pub fn name(mut self, name: &str) -> Self { self.name = Some(name.to_owned()); self } - pub fn no_pi(mut self, enable: bool) -> Self { - self.no_pi = enable; - self - } + pub fn no_pi(mut self, enable: bool) { self.no_pi = enable.then_some(()); } - pub fn tun_excl(mut self, enable: bool) -> Self { - self.tun_excl = enable; - self - } - - pub fn address(mut self, address: impl ToString) -> Self { - self.address = Some(address.to_string()); - self - } + pub fn tun_excl(mut self, enable: bool) { self.tun_excl = enable.then_some(()); } #[throws] - pub fn open(self) -> TunInterface { - TunInterface::new_with_options(self)? - } + pub fn open(self) -> TunInterface { TunInterface::new_with_options(self)? } } diff --git a/tun/src/tokio/mod.rs b/tun/src/tokio/mod.rs index 947fb74..8d23b7b 100644 --- a/tun/src/tokio/mod.rs +++ b/tun/src/tokio/mod.rs @@ -5,14 +5,15 @@ use tracing::instrument; #[derive(Debug)] pub struct TunInterface { - pub inner: AsyncFd, + inner: AsyncFd, } impl TunInterface { #[instrument] - pub fn new(mut tun: crate::TunInterface) -> io::Result { - tun.set_nonblocking(true)?; - Ok(Self { inner: AsyncFd::new(tun)? }) + pub fn new(tun: crate::TunInterface) -> io::Result { + Ok(Self { + inner: AsyncFd::new(tun)?, + }) } #[instrument] @@ -26,17 +27,38 @@ impl TunInterface { } } - #[instrument] - pub async fn recv(&self, buf: &mut [u8]) -> io::Result { + // #[instrument] + pub async fn recv(&mut self, buf: &mut [u8]) -> io::Result { loop { - let mut guard = self.inner.readable().await?; - match guard.try_io(|inner| inner.get_ref().recv(buf)) { - Ok(result) => return result, + log::debug!("TunInterface receiving..."); + let mut guard = self.inner.readable_mut().await?; + log::debug!("Got! readable_mut"); + match guard.try_io(|inner| { + // log::debug!("Got! {:#?}", inner); + let raw_ref = (*inner).get_mut(); + // log::debug!("Got mut ref! {:#?}", raw_ref); + let recved = raw_ref.recv(buf); + // log::debug!("Got recved! {:#?}", recved); + recved + }) { + Ok(result) => { + log::debug!("HORRAY"); + return result + }, Err(_would_block) => { - tracing::debug!("WouldBlock"); + log::debug!("WouldBlock"); continue - } + }, } } } + + #[instrument] + pub async fn try_recv(&mut self, buf: &mut [u8]) -> io::Result { + let mut guard = self.inner.readable_mut().await?; + match guard.try_io(|inner| (*inner).get_mut().recv(buf)) { + Ok(result) => Ok(result.unwrap_or_default()), + Err(_would_block) => Err(io::Error::new(io::ErrorKind::WouldBlock, "WouldBlock")), + } + } } diff --git a/tun/src/unix/apple/mod.rs b/tun/src/unix/apple/mod.rs index ab08505..83dbdc1 100644 --- a/tun/src/unix/apple/mod.rs +++ b/tun/src/unix/apple/mod.rs @@ -1,8 +1,8 @@ use std::{ io::{Error, IoSlice}, - mem::{self, ManuallyDrop}, + mem, net::{Ipv4Addr, SocketAddrV4}, - os::fd::{AsRawFd, FromRawFd, RawFd}, + os::fd::{AsRawFd, RawFd}, }; use byteorder::{ByteOrder, NetworkEndian}; @@ -11,14 +11,13 @@ use libc::{c_char, iovec, writev, AF_INET, AF_INET6}; use socket2::{Domain, SockAddr, Socket, Type}; use tracing::{self, instrument}; -pub mod kern_control; -pub mod sys; +mod kern_control; +mod sys; use kern_control::SysControlSocket; pub use super::queue::TunQueue; -use super::{ifname_to_string, string_to_ifname}; -use crate::TunOptions; +use super::{ifname_to_string, string_to_ifname, TunOptions}; #[derive(Debug)] pub struct TunInterface { @@ -34,42 +33,8 @@ impl TunInterface { #[throws] #[instrument] - pub fn new_with_options(options: TunOptions) -> TunInterface { - let ti = TunInterface::connect(0)?; - ti.configure(options)?; - ti - } - - pub fn retrieve() -> Option { - (3..100) - .filter_map(|fd| unsafe { - let peer_addr = socket2::SockAddr::init(|storage, len| { - *len = mem::size_of::() as u32; - libc::getpeername(fd, storage as *mut _, len); - Ok(()) - }) - .map(|(_, addr)| (fd, addr)); - peer_addr.ok() - }) - .filter(|(_fd, addr)| { - let ctl_addr = unsafe { &*(addr.as_ptr() as *const libc::sockaddr_ctl) }; - addr.family() == libc::AF_SYSTEM as u8 - && ctl_addr.ss_sysaddr == libc::AF_SYS_CONTROL as u16 - }) - .map(|(fd, _)| { - let socket = unsafe { socket2::Socket::from_raw_fd(fd) }; - TunInterface { socket } - }) - .next() - } - - #[throws] - fn configure(&self, options: TunOptions) { - if let Some(addr) = options.address { - if let Ok(addr) = addr.parse() { - self.set_ipv4_addr(addr)?; - } - } + pub fn new_with_options(_: TunOptions) -> TunInterface { + TunInterface::connect(0)? } #[throws] diff --git a/tun/src/unix/apple/sys.rs b/tun/src/unix/apple/sys.rs index c0ea613..b4d4a6a 100644 --- a/tun/src/unix/apple/sys.rs +++ b/tun/src/unix/apple/sys.rs @@ -2,11 +2,20 @@ use std::mem; use libc::{c_char, c_int, c_short, c_uint, c_ulong, sockaddr}; pub use libc::{ - c_void, sockaddr_ctl, sockaddr_in, socklen_t, AF_SYSTEM, AF_SYS_CONTROL, IFNAMSIZ, + c_void, + sockaddr_ctl, + sockaddr_in, + socklen_t, + AF_SYSTEM, + AF_SYS_CONTROL, + IFNAMSIZ, SYSPROTO_CONTROL, }; use nix::{ - ioctl_read_bad, ioctl_readwrite, ioctl_write_ptr_bad, request_code_readwrite, + ioctl_read_bad, + ioctl_readwrite, + ioctl_write_ptr_bad, + request_code_readwrite, request_code_write, }; diff --git a/tun/src/unix/linux/mod.rs b/tun/src/unix/linux/mod.rs index b4b5b8c..90cf353 100644 --- a/tun/src/unix/linux/mod.rs +++ b/tun/src/unix/linux/mod.rs @@ -26,9 +26,7 @@ pub struct TunInterface { impl TunInterface { #[throws] #[instrument] - pub fn new() -> TunInterface { - Self::new_with_options(TunOptions::new())? - } + pub fn new() -> TunInterface { Self::new_with_options(TunOptions::new())? } #[throws] #[instrument] @@ -214,7 +212,5 @@ impl TunInterface { #[throws] #[instrument] - pub fn send(&self, buf: &[u8]) -> usize { - self.socket.send(buf)? - } + pub fn send(&self, buf: &[u8]) -> usize { self.socket.send(buf)? } } diff --git a/tun/src/unix/mod.rs b/tun/src/unix/mod.rs index 77a1158..407d425 100644 --- a/tun/src/unix/mod.rs +++ b/tun/src/unix/mod.rs @@ -1,11 +1,12 @@ use std::{ io::{Error, Read}, - mem::MaybeUninit, os::fd::{AsRawFd, FromRawFd, IntoRawFd, RawFd}, }; use tracing::instrument; +use super::TunOptions; + mod queue; #[cfg(target_vendor = "apple")] @@ -39,26 +40,21 @@ impl IntoRawFd for TunInterface { } } -unsafe fn assume_init(buf: &[MaybeUninit]) -> &[u8] { - &*(buf as *const [MaybeUninit] as *const [u8]) -} - impl TunInterface { #[throws] #[instrument] - pub fn recv(&self, buf: &mut [u8]) -> usize { - // Use IoVec to read directly into target buffer - let mut tmp_buf = [MaybeUninit::uninit(); 1500]; - let len = self.socket.recv(&mut tmp_buf)?; - let result_buf = unsafe { assume_init(&tmp_buf[4..len]) }; - buf[..len - 4].copy_from_slice(&result_buf); - len - 4 + pub fn recv(&mut self, buf: &mut [u8]) -> usize { + // there might be a more efficient way to implement this + let tmp_buf = &mut [0u8; 1500]; + let len = self.socket.read(tmp_buf)?; + buf[..len-4].copy_from_slice(&tmp_buf[4..len]); + len-4 } #[throws] #[instrument] - pub fn set_nonblocking(&mut self, nb: bool) { - self.socket.set_nonblocking(nb)?; + pub fn set_timeout(&self, timeout: Option) { + self.socket.set_read_timeout(timeout)?; } } diff --git a/tun/src/windows/mod.rs b/tun/src/windows/mod.rs index dadd53f..9b6d5ad 100644 --- a/tun/src/windows/mod.rs +++ b/tun/src/windows/mod.rs @@ -25,9 +25,7 @@ impl Debug for TunInterface { impl TunInterface { #[throws] - pub fn new() -> TunInterface { - Self::new_with_options(TunOptions::new())? - } + pub fn new() -> TunInterface { Self::new_with_options(TunOptions::new())? } #[throws] pub(crate) fn new_with_options(options: TunOptions) -> TunInterface { @@ -39,18 +37,17 @@ impl TunInterface { if handle.is_null() { unsafe { GetLastError() }.ok()? } - TunInterface { handle, name: name_owned } + TunInterface { + handle, + name: name_owned, + } } - pub fn name(&self) -> String { - self.name.clone() - } + pub fn name(&self) -> String { self.name.clone() } } impl Drop for TunInterface { - fn drop(&mut self) { - unsafe { sys::WINTUN.WintunCloseAdapter(self.handle) } - } + fn drop(&mut self) { unsafe { sys::WINTUN.WintunCloseAdapter(self.handle) } } } pub(crate) mod sys { diff --git a/tun/tests/configure.rs b/tun/tests/configure.rs index e7e2c6d..6ef597b 100644 --- a/tun/tests/configure.rs +++ b/tun/tests/configure.rs @@ -5,9 +5,7 @@ use tun::TunInterface; #[test] #[throws] -fn test_create() { - TunInterface::new()?; -} +fn test_create() { TunInterface::new()?; } #[test] #[throws]