From d012ca144c96123f603ec6dd38cfcb3294664c18 Mon Sep 17 00:00:00 2001 From: Jett Chen Date: Wed, 29 Nov 2023 21:56:57 +0800 Subject: [PATCH 01/10] Lower timeout interval --- burrow/src/daemon/mod.rs | 2 +- burrow/src/wireguard/iface.rs | 6 +++--- burrow/src/wireguard/pcb.rs | 8 ++++---- tun/src/tokio/mod.rs | 7 ++----- 4 files changed, 10 insertions(+), 13 deletions(-) diff --git a/burrow/src/daemon/mod.rs b/burrow/src/daemon/mod.rs index e086452..bfbc1a2 100644 --- a/burrow/src/daemon/mod.rs +++ b/burrow/src/daemon/mod.rs @@ -47,7 +47,7 @@ pub async fn daemon_main() -> Result<()> { let mut _tun = tun::TunInterface::new()?; _tun.set_ipv4_addr(Ipv4Addr::from([192, 168, 1, 10]))?; - _tun.set_timeout(Some(std::time::Duration::from_secs(1)))?; + _tun.set_timeout(Some(std::time::Duration::from_millis(10)))?; let tun = tun::tokio::TunInterface::new(_tun)?; let private_key = parse_secret_key("GNqIAOCRxjl/cicZyvkvpTklgQuUmGUIEkH7IXF/sEE=")?; diff --git a/burrow/src/wireguard/iface.rs b/burrow/src/wireguard/iface.rs index 7f6473c..cfd19a5 100755 --- a/burrow/src/wireguard/iface.rs +++ b/burrow/src/wireguard/iface.rs @@ -101,12 +101,12 @@ impl Interface { let outgoing = async move { loop { - log::debug!("starting loop..."); + // log::debug!("starting loop..."); let mut buf = [0u8; 3000]; let src = { - log::debug!("awaiting read..."); - let src = match timeout(Duration::from_secs(2), tun.write().await.recv(&mut buf[..])).await { + // log::debug!("awaiting read..."); + let src = match timeout(Duration::from_millis(10), tun.write().await.recv(&mut buf[..])).await { Ok(Ok(len)) => &buf[..len], Ok(Err(e)) => {continue} Err(_would_block) => { diff --git a/burrow/src/wireguard/pcb.rs b/burrow/src/wireguard/pcb.rs index 4ec63c5..840eda8 100755 --- a/burrow/src/wireguard/pcb.rs +++ b/burrow/src/wireguard/pcb.rs @@ -68,15 +68,15 @@ impl PeerPcb { let rid: i32 = random(); log::debug!("start read loop {}", rid); loop{ - log::debug!("{}: waiting for packet", rid); + // log::debug!("{}: waiting for packet", rid); let Some(socket) = &self.socket else { continue }; let mut res_buf = [0;1500]; - log::debug!("{} : waiting for readability on {:?}", rid, socket); - match timeout(Duration::from_secs(2), socket.readable()).await { + // log::debug!("{} : waiting for readability on {:?}", rid, socket); + match timeout(Duration::from_millis(10), socket.readable()).await { Err(e) => { - log::debug!("{}: timeout waiting for readability on {:?}", rid, e); + // log::debug!("{}: timeout waiting for readability on {:?}", rid, e); continue } Ok(Err(e)) => { diff --git a/tun/src/tokio/mod.rs b/tun/src/tokio/mod.rs index 8d23b7b..2ade0a1 100644 --- a/tun/src/tokio/mod.rs +++ b/tun/src/tokio/mod.rs @@ -30,15 +30,12 @@ impl TunInterface { // #[instrument] pub async fn recv(&mut self, buf: &mut [u8]) -> io::Result { loop { - log::debug!("TunInterface receiving..."); + // log::debug!("TunInterface receiving..."); let mut guard = self.inner.readable_mut().await?; - log::debug!("Got! readable_mut"); + // log::debug!("Got! readable_mut"); match guard.try_io(|inner| { - // log::debug!("Got! {:#?}", inner); let raw_ref = (*inner).get_mut(); - // log::debug!("Got mut ref! {:#?}", raw_ref); let recved = raw_ref.recv(buf); - // log::debug!("Got recved! {:#?}", recved); recved }) { Ok(result) => { From 7d0c0250c5b1f80edde35721de914544d824557d Mon Sep 17 00:00:00 2001 From: Jett Chen Date: Sun, 3 Dec 2023 01:27:06 +0800 Subject: [PATCH 02/10] Fix async problem remove timeouts --- burrow/src/daemon/mod.rs | 2 +- burrow/src/wireguard/pcb.rs | 14 +++----------- tun/src/unix/mod.rs | 4 ++-- 3 files changed, 6 insertions(+), 14 deletions(-) diff --git a/burrow/src/daemon/mod.rs b/burrow/src/daemon/mod.rs index bfbc1a2..719bab5 100644 --- a/burrow/src/daemon/mod.rs +++ b/burrow/src/daemon/mod.rs @@ -47,7 +47,7 @@ pub async fn daemon_main() -> Result<()> { let mut _tun = tun::TunInterface::new()?; _tun.set_ipv4_addr(Ipv4Addr::from([192, 168, 1, 10]))?; - _tun.set_timeout(Some(std::time::Duration::from_millis(10)))?; + _tun.set_nonblocking(true)?; let tun = tun::tokio::TunInterface::new(_tun)?; let private_key = parse_secret_key("GNqIAOCRxjl/cicZyvkvpTklgQuUmGUIEkH7IXF/sEE=")?; diff --git a/burrow/src/wireguard/pcb.rs b/burrow/src/wireguard/pcb.rs index 840eda8..2e467b1 100755 --- a/burrow/src/wireguard/pcb.rs +++ b/burrow/src/wireguard/pcb.rs @@ -74,20 +74,12 @@ impl PeerPcb { }; let mut res_buf = [0;1500]; // log::debug!("{} : waiting for readability on {:?}", rid, socket); - match timeout(Duration::from_millis(10), socket.readable()).await { + let len = match socket.recv(&mut res_buf).await { + Ok(l) => {l} Err(e) => { - // log::debug!("{}: timeout waiting for readability on {:?}", rid, e); + log::error!("{}: error reading from socket: {:?}", rid, e); continue } - Ok(Err(e)) => { - log::debug!("{}: error waiting for readability on {:?}", rid, e); - continue - } - Ok(Ok(_)) => {} - }; - log::debug!("{}: readable!", rid); - let Ok(len) = socket.try_recv(&mut res_buf) else { - continue }; let mut res_dat = &res_buf[..len]; tracing::debug!("{}: Decapsulating {} bytes", rid, len); diff --git a/tun/src/unix/mod.rs b/tun/src/unix/mod.rs index 407d425..bd9ffb4 100644 --- a/tun/src/unix/mod.rs +++ b/tun/src/unix/mod.rs @@ -53,8 +53,8 @@ impl TunInterface { #[throws] #[instrument] - pub fn set_timeout(&self, timeout: Option) { - self.socket.set_read_timeout(timeout)?; + pub fn set_nonblocking(&mut self, nb: bool) { + self.socket.set_nonblocking(nb)?; } } From 4fbcdad49e6227e67689b32544634b4a67592f96 Mon Sep 17 00:00:00 2001 From: Jett Chen Date: Thu, 7 Dec 2023 00:51:41 +0800 Subject: [PATCH 03/10] fix misconfiguration --- burrow/src/daemon/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/burrow/src/daemon/mod.rs b/burrow/src/daemon/mod.rs index 719bab5..b42e350 100644 --- a/burrow/src/daemon/mod.rs +++ b/burrow/src/daemon/mod.rs @@ -46,7 +46,7 @@ pub async fn daemon_main() -> Result<()> { let mut inst = DaemonInstance::new(commands_rx, response_tx); let mut _tun = tun::TunInterface::new()?; - _tun.set_ipv4_addr(Ipv4Addr::from([192, 168, 1, 10]))?; + _tun.set_ipv4_addr(Ipv4Addr::from([10,13,13,2]))?; _tun.set_nonblocking(true)?; let tun = tun::tokio::TunInterface::new(_tun)?; From 48aba8ccb6e4940db58d0bd75f1c6f22aac8469a Mon Sep 17 00:00:00 2001 From: Jett Chen Date: Thu, 7 Dec 2023 00:51:52 +0800 Subject: [PATCH 04/10] add write to network on received packets --- burrow/src/wireguard/iface.rs | 6 +++--- burrow/src/wireguard/pcb.rs | 12 ++++++++++-- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/burrow/src/wireguard/iface.rs b/burrow/src/wireguard/iface.rs index cfd19a5..7d1b1ec 100755 --- a/burrow/src/wireguard/iface.rs +++ b/burrow/src/wireguard/iface.rs @@ -84,12 +84,12 @@ pub struct Interface { impl Interface { #[throws] pub fn new>(tun: TunInterface, peers: I) -> Self { - let mut pcbs: IndexedPcbs = peers + let tun = Arc::new(RwLock::new(tun)); + let pcbs: IndexedPcbs = peers .into_iter() - .map(|peer| PeerPcb::new(peer)) + .map(|peer| PeerPcb::new(peer, tun.clone())) .collect::>()?; - let tun = Arc::new(RwLock::new(tun)); let pcbs = Arc::new(pcbs); Self { tun, pcbs } } diff --git a/burrow/src/wireguard/pcb.rs b/burrow/src/wireguard/pcb.rs index 2e467b1..313913e 100755 --- a/burrow/src/wireguard/pcb.rs +++ b/burrow/src/wireguard/pcb.rs @@ -13,6 +13,7 @@ use tokio::{net::UdpSocket, task::JoinHandle}; use tokio::sync::{Mutex, RwLock}; use tokio::time::timeout; use uuid::uuid; +use tun::tokio::TunInterface; use super::{ iface::PacketInterface, @@ -27,11 +28,12 @@ pub struct PeerPcb { pub handle: Option>, socket: Option, tunnel: RwLock, + tun_interface: Arc> } impl PeerPcb { #[throws] - pub fn new(peer: Peer) -> Self { + pub fn new(peer: Peer, tun_interface: Arc>) -> Self { let tunnel = RwLock::new(Tunnel::new(peer.private_key, peer.public_key, peer.preshared_key, None, 1, None) .map_err(|s| anyhow::anyhow!("{}", s))?); @@ -41,6 +43,7 @@ impl PeerPcb { handle: None, socket: None, tunnel, + tun_interface } } @@ -102,9 +105,14 @@ impl PeerPcb { } TunnResult::WriteToTunnelV4(packet, addr) => { tracing::debug!("WriteToTunnelV4: {:?}, {:?}", packet, addr); + self.tun_interface.read().await.send(packet).await?; + continue; + } + TunnResult::WriteToTunnelV6(packet, addr) => { + tracing::debug!("WriteToTunnelV6: {:?}, {:?}", packet, addr); + self.tun_interface.read().await.send(packet).await?; continue; } - e => panic!("Unexpected result from decapsulate: {:?}", e), } } return Ok(len) From 60e5d1f8fd5283487c8704d12e074945bef8ae44 Mon Sep 17 00:00:00 2001 From: Jett Chen Date: Thu, 7 Dec 2023 11:45:42 +0800 Subject: [PATCH 05/10] Update daemon --- burrow/src/daemon/instance.rs | 13 +++++++++---- burrow/src/daemon/mod.rs | 12 ++++++++++-- burrow/src/wireguard/iface.rs | 6 +++--- tun/src/tokio/mod.rs | 2 +- 4 files changed, 23 insertions(+), 10 deletions(-) diff --git a/burrow/src/daemon/instance.rs b/burrow/src/daemon/instance.rs index f807ba2..bb94897 100644 --- a/burrow/src/daemon/instance.rs +++ b/burrow/src/daemon/instance.rs @@ -1,13 +1,14 @@ +use std::ops::Deref; use tracing::{debug, info, warn}; use DaemonResponse; -use tun::TunInterface; +use tun::tokio::TunInterface; use crate::daemon::response::{DaemonResponseData, ServerConfig, ServerInfo}; use super::*; pub struct DaemonInstance { rx: async_channel::Receiver, sx: async_channel::Sender, - tun_interface: Option, + tun_interface: Option>>, } impl DaemonInstance { @@ -19,13 +20,17 @@ impl DaemonInstance { } } + pub fn set_tun_interface(&mut self, tun_interface: Arc>) { + self.tun_interface = Some(tun_interface); + } + async fn proc_command(&mut self, command: DaemonCommand) -> Result { info!("Daemon got command: {:?}", command); match command { DaemonCommand::Start(st) => { if self.tun_interface.is_none() { debug!("Daemon attempting start tun interface."); - self.tun_interface = Some(st.tun.open()?); + self.tun_interface = Some(Arc::new(RwLock::new(TunInterface::new(st.tun.open()?)?))); info!("Daemon started tun interface"); } else { warn!("Got start, but tun interface already up."); @@ -39,7 +44,7 @@ impl DaemonInstance { info!("{:?}", ti); Ok( DaemonResponseData::ServerInfo( - ServerInfo::try_from(ti)? + ServerInfo::try_from(ti.read().await.inner.get_ref())? ) ) } diff --git a/burrow/src/daemon/mod.rs b/burrow/src/daemon/mod.rs index b42e350..b44efc1 100644 --- a/burrow/src/daemon/mod.rs +++ b/burrow/src/daemon/mod.rs @@ -1,4 +1,5 @@ use std::net::{Ipv4Addr, SocketAddr, ToSocketAddrs}; +use std::sync::Arc; mod command; @@ -11,6 +12,7 @@ use base64::{engine::general_purpose, Engine as _}; pub use command::{DaemonCommand, DaemonStartOptions}; use fehler::throws; use ip_network::{IpNetwork, Ipv4Network}; +use tokio::sync::RwLock; use instance::DaemonInstance; use crate::wireguard::{StaticSecret, Peer, Interface, PublicKey}; pub use net::DaemonClient; @@ -19,6 +21,7 @@ pub use net::DaemonClient; pub use net::start_srv; pub use response::{DaemonResponseData, DaemonResponse, ServerInfo}; +use crate::daemon::net::listen; #[throws] fn parse_key(string: &str) -> [u8; 32] { @@ -49,12 +52,16 @@ pub async fn daemon_main() -> Result<()> { _tun.set_ipv4_addr(Ipv4Addr::from([10,13,13,2]))?; _tun.set_nonblocking(true)?; let tun = tun::tokio::TunInterface::new(_tun)?; + let tun_ref = Arc::new(RwLock::new(tun)); let private_key = parse_secret_key("GNqIAOCRxjl/cicZyvkvpTklgQuUmGUIEkH7IXF/sEE=")?; let public_key = parse_public_key("uy75leriJay0+oHLhRMpV+A5xAQ0hCJ+q7Ww81AOvT4=")?; let preshared_key = Some(parse_key("s7lx/mg+reVEMnGnqeyYOQkzD86n2+gYnx1M9ygi08k=")?); let endpoint = "wg.burrow.rs:51820".to_socket_addrs()?.next().unwrap(); - let iface = Interface::new(tun, vec![Peer { + + inst.set_tun_interface(tun_ref.clone()); + + let iface = Interface::new(tun_ref, vec![Peer { endpoint, private_key, public_key, @@ -62,6 +69,7 @@ pub async fn daemon_main() -> Result<()> { allowed_ips: vec![IpNetwork::V4(Ipv4Network::DEFAULT_ROUTE)], }])?; - iface.run().await; + tokio::try_join!(iface.run(), inst.run(), listen(commands_tx, response_rx)) + .map(|_| {()}); Ok(()) } diff --git a/burrow/src/wireguard/iface.rs b/burrow/src/wireguard/iface.rs index 7d1b1ec..9f5dae4 100755 --- a/burrow/src/wireguard/iface.rs +++ b/burrow/src/wireguard/iface.rs @@ -83,8 +83,7 @@ pub struct Interface { impl Interface { #[throws] - pub fn new>(tun: TunInterface, peers: I) -> Self { - let tun = Arc::new(RwLock::new(tun)); + pub fn new>(tun: Arc>, peers: I) -> Self { let pcbs: IndexedPcbs = peers .into_iter() .map(|peer| PeerPcb::new(peer, tun.clone())) @@ -94,7 +93,7 @@ impl Interface { Self { tun, pcbs } } - pub async fn run(self) { + pub async fn run(self) -> anyhow::Result<()> { let pcbs = self.pcbs.clone(); let tun = self.tun.clone(); log::info!("starting interface"); @@ -187,5 +186,6 @@ impl Interface { } log::debug!("preparing to join.."); join_all(tsks).await; + Ok(()) } } diff --git a/tun/src/tokio/mod.rs b/tun/src/tokio/mod.rs index 2ade0a1..599e92c 100644 --- a/tun/src/tokio/mod.rs +++ b/tun/src/tokio/mod.rs @@ -5,7 +5,7 @@ use tracing::instrument; #[derive(Debug)] pub struct TunInterface { - inner: AsyncFd, + pub inner: AsyncFd, } impl TunInterface { From 347b78453f9892b5f1df0bccf24be86b008b26cf Mon Sep 17 00:00:00 2001 From: Jett Chen Date: Thu, 7 Dec 2023 12:12:48 +0800 Subject: [PATCH 06/10] Fix Duplicate Packet Error --- burrow/src/wireguard/pcb.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/burrow/src/wireguard/pcb.rs b/burrow/src/wireguard/pcb.rs index 313913e..f92acdc 100755 --- a/burrow/src/wireguard/pcb.rs +++ b/burrow/src/wireguard/pcb.rs @@ -106,12 +106,12 @@ impl PeerPcb { TunnResult::WriteToTunnelV4(packet, addr) => { tracing::debug!("WriteToTunnelV4: {:?}, {:?}", packet, addr); self.tun_interface.read().await.send(packet).await?; - continue; + break; } TunnResult::WriteToTunnelV6(packet, addr) => { tracing::debug!("WriteToTunnelV6: {:?}, {:?}", packet, addr); self.tun_interface.read().await.send(packet).await?; - continue; + break; } } } From ede0d13bca8ee8b86794de72cae57ac29546f391 Mon Sep 17 00:00:00 2001 From: Jett Chen Date: Sat, 9 Dec 2023 20:13:49 +0800 Subject: [PATCH 07/10] Start Tun Interface at Daemon Command --- burrow/src/daemon/command.rs | 2 +- burrow/src/daemon/instance.rs | 36 +++++++++++++++++++++++++++-------- burrow/src/daemon/mod.rs | 15 ++++----------- burrow/src/main.rs | 7 ++++++- burrow/src/wireguard/iface.rs | 35 +++++++++++++++++----------------- burrow/src/wireguard/pcb.rs | 18 ++++++++---------- tun/src/options.rs | 9 +++++++++ tun/src/tokio/mod.rs | 3 ++- tun/src/unix/apple/mod.rs | 10 ++++++++-- 9 files changed, 83 insertions(+), 52 deletions(-) diff --git a/burrow/src/daemon/command.rs b/burrow/src/daemon/command.rs index a5a1f30..cbe7f15 100644 --- a/burrow/src/daemon/command.rs +++ b/burrow/src/daemon/command.rs @@ -12,7 +12,7 @@ pub enum DaemonCommand { #[derive(Debug, Clone, Default, Serialize, Deserialize, JsonSchema)] pub struct DaemonStartOptions { - pub(super) tun: TunOptions, + pub tun: TunOptions, } #[test] diff --git a/burrow/src/daemon/instance.rs b/burrow/src/daemon/instance.rs index bb94897..073bc37 100644 --- a/burrow/src/daemon/instance.rs +++ b/burrow/src/daemon/instance.rs @@ -1,22 +1,35 @@ -use std::ops::Deref; +use tokio::task::JoinHandle; use tracing::{debug, info, warn}; use DaemonResponse; use tun::tokio::TunInterface; use crate::daemon::response::{DaemonResponseData, ServerConfig, ServerInfo}; use super::*; +enum RunState{ + Running(JoinHandle>), + Idle, +} + pub struct DaemonInstance { rx: async_channel::Receiver, sx: async_channel::Sender, tun_interface: Option>>, + wg_interface: Arc>, + wg_state: RunState, } impl DaemonInstance { - pub fn new(rx: async_channel::Receiver, sx: async_channel::Sender) -> Self { + pub fn new( + rx: async_channel::Receiver, + sx: async_channel::Sender, + wg_interface: Arc>, + ) -> Self { Self { rx, sx, + wg_interface, tun_interface: None, + wg_state: RunState::Idle, } } @@ -28,12 +41,19 @@ impl DaemonInstance { info!("Daemon got command: {:?}", command); match command { DaemonCommand::Start(st) => { - if self.tun_interface.is_none() { - debug!("Daemon attempting start tun interface."); - self.tun_interface = Some(Arc::new(RwLock::new(TunInterface::new(st.tun.open()?)?))); - info!("Daemon started tun interface"); - } else { - warn!("Got start, but tun interface already up."); + match self.wg_state { + RunState::Running(_) => {warn!("Got start, but tun interface already up.");} + RunState::Idle => { + let tun_if = Arc::new(RwLock::new(TunInterface::new(st.tun.open()?)?)); + self.tun_interface = Some(tun_if.clone()); + self.wg_interface.write().await.set_tun(tun_if); + let tmp_wg = self.wg_interface.clone(); + let run_task = tokio::spawn(async move { + tmp_wg.read().await.run().await + }); + self.wg_state = RunState::Running(run_task); + info!("Daemon started tun interface"); + } } Ok(DaemonResponseData::None) } diff --git a/burrow/src/daemon/mod.rs b/burrow/src/daemon/mod.rs index b44efc1..8814ce2 100644 --- a/burrow/src/daemon/mod.rs +++ b/burrow/src/daemon/mod.rs @@ -46,22 +46,13 @@ fn parse_public_key(string: &str) -> PublicKey { pub async fn daemon_main() -> Result<()> { let (commands_tx, commands_rx) = async_channel::unbounded(); let (response_tx, response_rx) = async_channel::unbounded(); - let mut inst = DaemonInstance::new(commands_rx, response_tx); - - let mut _tun = tun::TunInterface::new()?; - _tun.set_ipv4_addr(Ipv4Addr::from([10,13,13,2]))?; - _tun.set_nonblocking(true)?; - let tun = tun::tokio::TunInterface::new(_tun)?; - let tun_ref = Arc::new(RwLock::new(tun)); let private_key = parse_secret_key("GNqIAOCRxjl/cicZyvkvpTklgQuUmGUIEkH7IXF/sEE=")?; let public_key = parse_public_key("uy75leriJay0+oHLhRMpV+A5xAQ0hCJ+q7Ww81AOvT4=")?; let preshared_key = Some(parse_key("s7lx/mg+reVEMnGnqeyYOQkzD86n2+gYnx1M9ygi08k=")?); let endpoint = "wg.burrow.rs:51820".to_socket_addrs()?.next().unwrap(); - inst.set_tun_interface(tun_ref.clone()); - - let iface = Interface::new(tun_ref, vec![Peer { + let iface = Interface::new(vec![Peer { endpoint, private_key, public_key, @@ -69,7 +60,9 @@ pub async fn daemon_main() -> Result<()> { allowed_ips: vec![IpNetwork::V4(Ipv4Network::DEFAULT_ROUTE)], }])?; - tokio::try_join!(iface.run(), inst.run(), listen(commands_tx, response_rx)) + let mut inst = DaemonInstance::new(commands_rx, response_tx, Arc::new(RwLock::new(iface))); + + tokio::try_join!(inst.run(), listen(commands_tx, response_rx)) .map(|_| {()}); Ok(()) } diff --git a/burrow/src/main.rs b/burrow/src/main.rs index 2e89a48..125d763 100644 --- a/burrow/src/main.rs +++ b/burrow/src/main.rs @@ -18,6 +18,7 @@ mod daemon; mod wireguard; use daemon::{DaemonClient, DaemonCommand, DaemonStartOptions}; +use tun::TunOptions; use crate::daemon::DaemonResponseData; #[derive(Parser)] @@ -65,7 +66,11 @@ struct DaemonArgs {} async fn try_start() -> Result<()> { let mut client = DaemonClient::new().await?; client - .send_command(DaemonCommand::Start(DaemonStartOptions::default())) + .send_command(DaemonCommand::Start( + DaemonStartOptions{ + tun: TunOptions::new().address("10.13.13.2") + } + )) .await .map(|_| ()) } diff --git a/burrow/src/wireguard/iface.rs b/burrow/src/wireguard/iface.rs index 9f5dae4..4a00cbe 100755 --- a/burrow/src/wireguard/iface.rs +++ b/burrow/src/wireguard/iface.rs @@ -77,25 +77,32 @@ impl FromIterator for IndexedPcbs { } pub struct Interface { - tun: Arc>, + tun: Option>>, pcbs: Arc, } impl Interface { #[throws] - pub fn new>(tun: Arc>, peers: I) -> Self { + pub fn new>(peers: I) -> Self { let pcbs: IndexedPcbs = peers .into_iter() - .map(|peer| PeerPcb::new(peer, tun.clone())) + .map(|peer| PeerPcb::new(peer)) .collect::>()?; let pcbs = Arc::new(pcbs); - Self { tun, pcbs } + Self { + pcbs, + tun: None + } } - pub async fn run(self) -> anyhow::Result<()> { + pub fn set_tun(&mut self, tun: Arc>) { + self.tun = Some(tun); + } + + pub async fn run(&self) -> anyhow::Result<()> { let pcbs = self.pcbs.clone(); - let tun = self.tun.clone(); + let tun = self.tun.clone().ok_or(anyhow::anyhow!("tun interface does not exist"))?; log::info!("starting interface"); let outgoing = async move { @@ -143,24 +150,16 @@ impl Interface { continue }, }; - - // let mut buf = [0u8; 3000]; - // match pcbs.pcbs[idx].read().await.recv(&mut buf).await { - // Ok(len) => log::debug!("received {} bytes from peer {}", len, dst_addr), - // Err(e) => { - // log::error!("failed to receive packet {}", e); - // continue - // }, - // } } }; let mut tsks = vec![]; - let tun = self.tun.clone(); + let tun = self.tun.clone().ok_or(anyhow::anyhow!("tun interface does not exist"))?; + let outgoing = tokio::task::spawn(outgoing); tsks.push(outgoing); { - let pcbs = self.pcbs; + let pcbs = &self.pcbs; for i in 0..pcbs.pcbs.len(){ let mut pcb = pcbs.pcbs[i].clone(); let tun = tun.clone(); @@ -172,7 +171,7 @@ impl Interface { return } } - let r2 = pcb.read().await.run().await; + let r2 = pcb.read().await.run(tun).await; if let Err(e) = r2 { log::error!("failed to run pcb: {}", e); return diff --git a/burrow/src/wireguard/pcb.rs b/burrow/src/wireguard/pcb.rs index f92acdc..6fcaa15 100755 --- a/burrow/src/wireguard/pcb.rs +++ b/burrow/src/wireguard/pcb.rs @@ -28,12 +28,11 @@ pub struct PeerPcb { pub handle: Option>, socket: Option, tunnel: RwLock, - tun_interface: Arc> } impl PeerPcb { #[throws] - pub fn new(peer: Peer, tun_interface: Arc>) -> Self { + pub fn new(peer: Peer) -> Self { let tunnel = RwLock::new(Tunnel::new(peer.private_key, peer.public_key, peer.preshared_key, None, 1, None) .map_err(|s| anyhow::anyhow!("{}", s))?); @@ -42,8 +41,7 @@ impl PeerPcb { allowed_ips: peer.allowed_ips, handle: None, socket: None, - tunnel, - tun_interface + tunnel } } @@ -56,22 +54,22 @@ impl PeerPcb { Ok(()) } - pub async fn run(&self) -> Result<(), Error> { + pub async fn run(&self, tun_interface: Arc>) -> Result<(), Error> { let mut buf = [0u8; 3000]; log::debug!("starting read loop for pcb..."); loop { tracing::debug!("waiting for packet"); - let len = self.recv(&mut buf).await?; + let len = self.recv(&mut buf, tun_interface.clone()).await?; tracing::debug!("received {} bytes", len); } } - pub async fn recv(&self, buf: &mut [u8]) -> Result { + pub async fn recv(&self, buf: &mut [u8], tun_interface: Arc>) -> Result { log::debug!("starting read loop for pcb... for {:?}", &self); let rid: i32 = random(); log::debug!("start read loop {}", rid); loop{ - // log::debug!("{}: waiting for packet", rid); + log::debug!("{}: waiting for packet", rid); let Some(socket) = &self.socket else { continue }; @@ -105,12 +103,12 @@ impl PeerPcb { } TunnResult::WriteToTunnelV4(packet, addr) => { tracing::debug!("WriteToTunnelV4: {:?}, {:?}", packet, addr); - self.tun_interface.read().await.send(packet).await?; + tun_interface.read().await.send(packet).await?; break; } TunnResult::WriteToTunnelV6(packet, addr) => { tracing::debug!("WriteToTunnelV6: {:?}, {:?}", packet, addr); - self.tun_interface.read().await.send(packet).await?; + tun_interface.read().await.send(packet).await?; break; } } diff --git a/tun/src/options.rs b/tun/src/options.rs index 3fe5a13..82cadfd 100644 --- a/tun/src/options.rs +++ b/tun/src/options.rs @@ -13,6 +13,10 @@ pub struct TunOptions { pub(crate) no_pi: Option<()>, /// (Linux) Avoid opening an existing persistant device. pub(crate) tun_excl: Option<()>, + /// (MacOS) Whether to seek the first available utun device. + pub(crate) seek_utun: Option<()>, + /// (Linux) The IP address of the tun interface. + pub(crate) address: Option, } impl TunOptions { @@ -27,6 +31,11 @@ impl TunOptions { pub fn tun_excl(mut self, enable: bool) { self.tun_excl = enable.then_some(()); } + pub fn address(mut self, address: impl ToString) -> Self { + self.address = Some(address.to_string()); + self + } + #[throws] pub fn open(self) -> TunInterface { TunInterface::new_with_options(self)? } } diff --git a/tun/src/tokio/mod.rs b/tun/src/tokio/mod.rs index 599e92c..fb924ff 100644 --- a/tun/src/tokio/mod.rs +++ b/tun/src/tokio/mod.rs @@ -10,7 +10,8 @@ pub struct TunInterface { impl TunInterface { #[instrument] - pub fn new(tun: crate::TunInterface) -> io::Result { + pub fn new(mut tun: crate::TunInterface) -> io::Result { + tun.set_nonblocking(true)?; Ok(Self { inner: AsyncFd::new(tun)?, }) diff --git a/tun/src/unix/apple/mod.rs b/tun/src/unix/apple/mod.rs index 83dbdc1..b419294 100644 --- a/tun/src/unix/apple/mod.rs +++ b/tun/src/unix/apple/mod.rs @@ -33,8 +33,14 @@ impl TunInterface { #[throws] #[instrument] - pub fn new_with_options(_: TunOptions) -> TunInterface { - TunInterface::connect(0)? + pub fn new_with_options(options: TunOptions) -> TunInterface { + let ti = TunInterface::connect(0)?; + if let Some(addr) = options.address{ + if let Ok(addr) = addr.parse() { + ti.set_ipv4_addr(addr)?; + } + } + ti } #[throws] From db1750a045ab149bb9afa9595672128bd28e45c5 Mon Sep 17 00:00:00 2001 From: Jett Chen Date: Sun, 10 Dec 2023 03:44:31 +0800 Subject: [PATCH 08/10] checkpoint --- Apple/NetworkExtension/BurrowIpc.swift | 2 +- Apple/NetworkExtension/DataTypes.swift | 30 +++++- .../PacketTunnelProvider.swift | 12 ++- burrow/src/apple.rs | 10 +- burrow/src/daemon/command.rs | 26 ++--- burrow/src/daemon/instance.rs | 61 ++++++++---- burrow/src/daemon/mod.rs | 32 ++++-- burrow/src/daemon/net/apple.rs | 17 ++-- burrow/src/daemon/net/mod.rs | 1 - burrow/src/daemon/net/systemd.rs | 18 +++- burrow/src/daemon/net/unix.rs | 58 ++++++----- burrow/src/daemon/net/windows.rs | 5 +- burrow/src/daemon/response.rs | 99 ++++++++++--------- ...ommand__daemoncommand_serialization-2.snap | 4 +- ...ommand__daemoncommand_serialization-3.snap | 4 +- ...ommand__daemoncommand_serialization-4.snap | 4 +- ...ommand__daemoncommand_serialization-5.snap | 5 + ..._command__daemoncommand_serialization.snap | 2 +- ...n__response__response_serialization-4.snap | 2 +- burrow/src/lib.rs | 25 +---- burrow/src/main.rs | 45 ++++----- burrow/src/wireguard/iface.rs | 82 ++++++++------- burrow/src/wireguard/noise/handshake.rs | 41 ++++---- burrow/src/wireguard/noise/mod.rs | 24 ++--- burrow/src/wireguard/noise/rate_limiter.rs | 24 ++--- burrow/src/wireguard/noise/session.rs | 16 +-- burrow/src/wireguard/noise/timers.rs | 10 +- burrow/src/wireguard/pcb.rs | 53 ++++++---- burrow/src/wireguard/peer.rs | 2 +- tun/build.rs | 2 +- tun/src/lib.rs | 2 +- tun/src/options.rs | 33 +++++-- tun/src/tokio/mod.rs | 20 ++-- tun/src/unix/apple/mod.rs | 14 ++- tun/src/unix/apple/sys.rs | 13 +-- tun/src/unix/linux/mod.rs | 8 +- tun/src/unix/mod.rs | 46 ++++++++- tun/src/windows/mod.rs | 17 ++-- tun/tests/configure.rs | 4 +- 39 files changed, 514 insertions(+), 359 deletions(-) create mode 100644 burrow/src/daemon/snapshots/burrow__daemon__command__daemoncommand_serialization-5.snap diff --git a/Apple/NetworkExtension/BurrowIpc.swift b/Apple/NetworkExtension/BurrowIpc.swift index 7f18679..279cdf1 100644 --- a/Apple/NetworkExtension/BurrowIpc.swift +++ b/Apple/NetworkExtension/BurrowIpc.swift @@ -113,7 +113,7 @@ final class BurrowIpc { return data } - func request(_ request: Request, type: U.Type) async throws -> U { + func request(_ request: any Request, type: U.Type) async throws -> U { do { var data: Data = try JSONEncoder().encode(request) data.append(contentsOf: [10]) diff --git a/Apple/NetworkExtension/DataTypes.swift b/Apple/NetworkExtension/DataTypes.swift index b228d77..6b3a070 100644 --- a/Apple/NetworkExtension/DataTypes.swift +++ b/Apple/NetworkExtension/DataTypes.swift @@ -7,16 +7,40 @@ enum BurrowError: Error { case resultIsNone } -protocol Request: Codable { +protocol Request: Codable where T: Codable{ + associatedtype T var id: UInt { get set } - var command: String { get set } + var command: T { get set } } -struct BurrowRequest: Request { +struct BurrowSingleCommand: Request { var id: UInt var command: String } +struct BurrowRequest: Request where T: Codable{ + var id: UInt + var command: T +} + +struct BurrowStartRequest: Codable { + struct TunOptions: Codable{ + let name: String? + let no_pi: Bool + let tun_excl: Bool + let seek_utun: Int? + let address: String? + } + struct StartOptions: Codable{ + let tun: TunOptions + } + let Start: StartOptions +} + +func start_req_fd(id: UInt, fd: Int) -> BurrowRequest { + return BurrowRequest(id: id, command: BurrowStartRequest(Start: BurrowStartRequest.StartOptions(tun: BurrowStartRequest.TunOptions(name: nil, no_pi: false, tun_excl: false, seek_utun: fd, address: nil)))) +} + struct Response: Decodable where T: Decodable { var id: UInt var result: T diff --git a/Apple/NetworkExtension/PacketTunnelProvider.swift b/Apple/NetworkExtension/PacketTunnelProvider.swift index 4b72115..8260aa0 100644 --- a/Apple/NetworkExtension/PacketTunnelProvider.swift +++ b/Apple/NetworkExtension/PacketTunnelProvider.swift @@ -17,7 +17,7 @@ class PacketTunnelProvider: NEPacketTunnelProvider { logger.info("Started server") Task { do { - let command = BurrowRequest(id: 0, command: "ServerConfig") + let command = BurrowSingleCommand(id: 0, command: "ServerConfig") guard let data = try await client?.request(command, type: Response>.self) else { throw BurrowError.cantParseResult @@ -32,6 +32,16 @@ class PacketTunnelProvider: NEPacketTunnelProvider { } try await self.setTunnelNetworkSettings(tunNs) self.logger.info("Set remote tunnel address to \(tunNs.tunnelRemoteAddress)") + +// let tunFd = self.packetFlow.value(forKeyPath: "socket.fileDescriptor") as! Int; +// self.logger.info("Found File Descriptor: \(tunFd)") + let start_command = start_req_fd(id: 1, fd: 0) + guard let data = try await client?.request(start_command, type: Response>.self) + else { + throw BurrowError.cantParseResult + } + let encoded_startres = try JSONEncoder().encode(data.result) + self.logger.log("Received start server response: \(String(decoding: encoded_startres, as: UTF8.self))") completionHandler(nil) } catch { self.logger.error("An error occurred: \(error)") diff --git a/burrow/src/apple.rs b/burrow/src/apple.rs index 0a96877..dd50fc2 100644 --- a/burrow/src/apple.rs +++ b/burrow/src/apple.rs @@ -1,15 +1,15 @@ -use tracing::{debug, Subscriber}; use tracing::instrument::WithSubscriber; +use tracing::{debug, Subscriber}; use tracing_oslog::OsLogger; -use tracing_subscriber::FmtSubscriber; use tracing_subscriber::layer::SubscriberExt; +use tracing_subscriber::FmtSubscriber; pub use crate::daemon::start_srv; #[no_mangle] pub extern "C" fn initialize_oslog() { - let collector = tracing_subscriber::registry() - .with(OsLogger::new("com.hackclub.burrow", "backend")); + let collector = + tracing_subscriber::registry().with(OsLogger::new("com.hackclub.burrow", "backend")); tracing::subscriber::set_global_default(collector).unwrap(); debug!("Initialized oslog tracing in libburrow rust FFI"); -} \ No newline at end of file +} diff --git a/burrow/src/daemon/command.rs b/burrow/src/daemon/command.rs index cbe7f15..776e172 100644 --- a/burrow/src/daemon/command.rs +++ b/burrow/src/daemon/command.rs @@ -17,16 +17,20 @@ pub struct DaemonStartOptions { #[test] fn test_daemoncommand_serialization() { + insta::assert_snapshot!(serde_json::to_string(&DaemonCommand::Start( + DaemonStartOptions::default() + )) + .unwrap()); insta::assert_snapshot!( - serde_json::to_string(&DaemonCommand::Start(DaemonStartOptions::default())).unwrap() + serde_json::to_string(&DaemonCommand::Start(DaemonStartOptions { + tun: TunOptions { + seek_utun: true, + ..TunOptions::default() + } + })) + .unwrap() ); - insta::assert_snapshot!( - serde_json::to_string(&DaemonCommand::ServerInfo).unwrap() - ); - insta::assert_snapshot!( - serde_json::to_string(&DaemonCommand::Stop).unwrap() - ); - insta::assert_snapshot!( - serde_json::to_string(&DaemonCommand::ServerConfig).unwrap() - ) -} \ No newline at end of file + insta::assert_snapshot!(serde_json::to_string(&DaemonCommand::ServerInfo).unwrap()); + insta::assert_snapshot!(serde_json::to_string(&DaemonCommand::Stop).unwrap()); + insta::assert_snapshot!(serde_json::to_string(&DaemonCommand::ServerConfig).unwrap()) +} diff --git a/burrow/src/daemon/instance.rs b/burrow/src/daemon/instance.rs index 073bc37..c79da05 100644 --- a/burrow/src/daemon/instance.rs +++ b/burrow/src/daemon/instance.rs @@ -1,11 +1,12 @@ -use tokio::task::JoinHandle; -use tracing::{debug, info, warn}; -use DaemonResponse; -use tun::tokio::TunInterface; -use crate::daemon::response::{DaemonResponseData, ServerConfig, ServerInfo}; use super::*; +use crate::daemon::response::{DaemonResponseData, ServerConfig, ServerInfo}; +use tokio::task::JoinHandle; +use tracing::field::debug; +use tracing::{debug, info, warn}; +use tun::tokio::TunInterface; +use DaemonResponse; -enum RunState{ +enum RunState { Running(JoinHandle>), Idle, } @@ -42,34 +43,53 @@ impl DaemonInstance { match command { DaemonCommand::Start(st) => { match self.wg_state { - RunState::Running(_) => {warn!("Got start, but tun interface already up.");} + RunState::Running(_) => { + warn!("Got start, but tun interface already up."); + } RunState::Idle => { + debug!("Creating new TunInterface"); let tun_if = Arc::new(RwLock::new(TunInterface::new(st.tun.open()?)?)); + debug!("TunInterface created: {:?}", tun_if); + + debug!("Setting tun_interface"); self.tun_interface = Some(tun_if.clone()); + debug!("tun_interface set: {:?}", self.tun_interface); + + debug!("Setting tun on wg_interface"); self.wg_interface.write().await.set_tun(tun_if); + debug!("tun set on wg_interface"); + + debug!("Cloning wg_interface"); let tmp_wg = self.wg_interface.clone(); + debug!("wg_interface cloned"); + + debug!("Spawning run task"); let run_task = tokio::spawn(async move { - tmp_wg.read().await.run().await + debug!("Running wg_interface"); + let twlock = tmp_wg.read().await; + debug!("wg_interface read lock acquired"); + twlock.run().await }); + debug!("Run task spawned: {:?}", run_task); + + debug!("Setting wg_state to Running"); self.wg_state = RunState::Running(run_task); + debug!("wg_state set to Running"); + info!("Daemon started tun interface"); } } Ok(DaemonResponseData::None) } - DaemonCommand::ServerInfo => { - match &self.tun_interface { - None => {Ok(DaemonResponseData::None)} - Some(ti) => { - info!("{:?}", ti); - Ok( - DaemonResponseData::ServerInfo( - ServerInfo::try_from(ti.read().await.inner.get_ref())? - ) - ) - } + DaemonCommand::ServerInfo => match &self.tun_interface { + None => Ok(DaemonResponseData::None), + Some(ti) => { + info!("{:?}", ti); + Ok(DaemonResponseData::ServerInfo(ServerInfo::try_from( + ti.read().await.inner.get_ref(), + )?)) } - } + }, DaemonCommand::Stop => { if self.tun_interface.is_some() { self.tun_interface = None; @@ -86,6 +106,7 @@ impl DaemonInstance { } pub async fn run(&mut self) -> Result<()> { + tracing::info!("BEGIN"); while let Ok(command) = self.rx.recv().await { let response = self.proc_command(command).await; info!("Daemon response: {:?}", response); diff --git a/burrow/src/daemon/mod.rs b/burrow/src/daemon/mod.rs index 8814ce2..1aa6ea4 100644 --- a/burrow/src/daemon/mod.rs +++ b/burrow/src/daemon/mod.rs @@ -1,27 +1,26 @@ -use std::net::{Ipv4Addr, SocketAddr, ToSocketAddrs}; +use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4, ToSocketAddrs}; use std::sync::Arc; - mod command; mod instance; mod net; mod response; +use crate::wireguard::{Interface, Peer, PublicKey, StaticSecret}; use anyhow::{Error, Result}; use base64::{engine::general_purpose, Engine as _}; pub use command::{DaemonCommand, DaemonStartOptions}; use fehler::throws; -use ip_network::{IpNetwork, Ipv4Network}; -use tokio::sync::RwLock; use instance::DaemonInstance; -use crate::wireguard::{StaticSecret, Peer, Interface, PublicKey}; +use ip_network::{IpNetwork, Ipv4Network}; pub use net::DaemonClient; +use tokio::sync::RwLock; #[cfg(target_vendor = "apple")] pub use net::start_srv; -pub use response::{DaemonResponseData, DaemonResponse, ServerInfo}; use crate::daemon::net::listen; +pub use response::{DaemonResponse, DaemonResponseData, ServerInfo}; #[throws] fn parse_key(string: &str) -> [u8; 32] { @@ -50,7 +49,7 @@ pub async fn daemon_main() -> Result<()> { let private_key = parse_secret_key("GNqIAOCRxjl/cicZyvkvpTklgQuUmGUIEkH7IXF/sEE=")?; let public_key = parse_public_key("uy75leriJay0+oHLhRMpV+A5xAQ0hCJ+q7Ww81AOvT4=")?; let preshared_key = Some(parse_key("s7lx/mg+reVEMnGnqeyYOQkzD86n2+gYnx1M9ygi08k=")?); - let endpoint = "wg.burrow.rs:51820".to_socket_addrs()?.next().unwrap(); + let endpoint = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(198, 18, 6, 180)), 51820); // DNS lookup under macos fails, somehow let iface = Interface::new(vec![Peer { endpoint, @@ -62,7 +61,22 @@ pub async fn daemon_main() -> Result<()> { let mut inst = DaemonInstance::new(commands_rx, response_tx, Arc::new(RwLock::new(iface))); - tokio::try_join!(inst.run(), listen(commands_tx, response_rx)) - .map(|_| {()}); + tracing::info!("Starting daemon jobs..."); + + let inst_job = tokio::spawn(async move { + let res = inst.run().await; + if let Err(e) = res { + tracing::error!("Error when running instance: {}", e); + } + }); + + let listen_job = tokio::spawn(async move { + let res = listen(commands_tx, response_rx).await; + if let Err(e) = res { + tracing::error!("Error when listening: {}", e); + } + }); + + tokio::try_join!(inst_job, listen_job).map(|_| ()); Ok(()) } diff --git a/burrow/src/daemon/net/apple.rs b/burrow/src/daemon/net/apple.rs index e53bdaa..1242dfe 100644 --- a/burrow/src/daemon/net/apple.rs +++ b/burrow/src/daemon/net/apple.rs @@ -1,10 +1,12 @@ +use crate::daemon::{daemon_main, DaemonClient}; +use std::future::Future; use std::thread; use tokio::runtime::Runtime; -use tracing::error; -use crate::daemon::{daemon_main, DaemonClient}; +use tracing::{error, info}; #[no_mangle] -pub extern "C" fn start_srv(){ +pub extern "C" fn start_srv() { + info!("Rust: Starting server"); let _handle = thread::spawn(move || { let rt = Runtime::new().unwrap(); rt.block_on(async { @@ -16,9 +18,12 @@ pub extern "C" fn start_srv(){ let rt = Runtime::new().unwrap(); rt.block_on(async { loop { - if let Ok(_) = DaemonClient::new().await{ - break + match DaemonClient::new().await { + Ok(_) => break, + Err(e) => { + // error!("Error when connecting to daemon: {}", e) + } } } }); -} \ No newline at end of file +} diff --git a/burrow/src/daemon/net/mod.rs b/burrow/src/daemon/net/mod.rs index e5865a3..d369f40 100644 --- a/burrow/src/daemon/net/mod.rs +++ b/burrow/src/daemon/net/mod.rs @@ -29,4 +29,3 @@ pub struct DaemonRequest { pub id: u32, pub command: DaemonCommand, } - diff --git a/burrow/src/daemon/net/systemd.rs b/burrow/src/daemon/net/systemd.rs index 1a41f76..8a2b29c 100644 --- a/burrow/src/daemon/net/systemd.rs +++ b/burrow/src/daemon/net/systemd.rs @@ -1,13 +1,23 @@ -pub async fn listen(cmd_tx: async_channel::Sender, rsp_rx: async_channel::Receiver) -> Result<()> { - if !libsystemd::daemon::booted() || listen_with_systemd(cmd_tx.clone(), rsp_rx.clone()).await.is_err() { +pub async fn listen( + cmd_tx: async_channel::Sender, + rsp_rx: async_channel::Receiver, +) -> Result<()> { + if !libsystemd::daemon::booted() + || listen_with_systemd(cmd_tx.clone(), rsp_rx.clone()) + .await + .is_err() + { unix::listen(cmd_tx, rsp_rx).await?; } Ok(()) } -async fn listen_with_systemd(cmd_tx: async_channel::Sender, rsp_rx: async_channel::Receiver) -> Result<()> { +async fn listen_with_systemd( + cmd_tx: async_channel::Sender, + rsp_rx: async_channel::Receiver, +) -> Result<()> { let fds = libsystemd::activation::receive_descriptors(false)?; - super::unix::listen_with_optional_fd(cmd_tx, rsp_rx,Some(fds[0].clone().into_raw_fd())).await + super::unix::listen_with_optional_fd(cmd_tx, rsp_rx, Some(fds[0].clone().into_raw_fd())).await } pub type DaemonClient = unix::DaemonClient; diff --git a/burrow/src/daemon/net/unix.rs b/burrow/src/daemon/net/unix.rs index 928473b..d0e5b26 100644 --- a/burrow/src/daemon/net/unix.rs +++ b/burrow/src/daemon/net/unix.rs @@ -1,23 +1,25 @@ use super::*; +use anyhow::anyhow; +use log::log; +use std::hash::Hash; +use std::path::PathBuf; use std::{ - ascii, io, os::{ + ascii, io, + os::{ fd::{FromRawFd, RawFd}, unix::net::UnixListener as StdUnixListener, }, - path::Path}; -use std::hash::Hash; -use std::path::PathBuf; -use anyhow::anyhow; -use log::log; + path::Path, +}; use tracing::info; +use crate::daemon::{DaemonCommand, DaemonResponse, DaemonResponseData}; use anyhow::Result; use tokio::{ io::{AsyncBufReadExt, AsyncWriteExt, BufReader}, net::{UnixListener, UnixStream}, }; use tracing::debug; -use crate::daemon::{DaemonCommand, DaemonResponse, DaemonResponseData}; #[cfg(not(target_vendor = "apple"))] const UNIX_SOCKET_PATH: &str = "/run/burrow.sock"; @@ -26,16 +28,18 @@ const UNIX_SOCKET_PATH: &str = "/run/burrow.sock"; const UNIX_SOCKET_PATH: &str = "burrow.sock"; #[cfg(target_os = "macos")] -fn fetch_socket_path() -> Option{ +fn fetch_socket_path() -> Option { let tries = vec![ "burrow.sock".to_string(), - format!("{}/Library/Containers/com.hackclub.burrow.network/Data/burrow.sock", - std::env::var("HOME").unwrap_or_default()) - .to_string(), + format!( + "{}/Library/Containers/com.hackclub.burrow.network/Data/burrow.sock", + std::env::var("HOME").unwrap_or_default() + ) + .to_string(), ]; - for path in tries{ + for path in tries { let path = PathBuf::from(path); - if path.exists(){ + if path.exists() { return Some(path); } } @@ -43,11 +47,14 @@ fn fetch_socket_path() -> Option{ } #[cfg(not(target_os = "macos"))] -fn fetch_socket_path() -> Option{ +fn fetch_socket_path() -> Option { Some(Path::new(UNIX_SOCKET_PATH).to_path_buf()) } -pub async fn listen(cmd_tx: async_channel::Sender, rsp_rx: async_channel::Receiver) -> Result<()> { +pub async fn listen( + cmd_tx: async_channel::Sender, + rsp_rx: async_channel::Receiver, +) -> Result<()> { listen_with_optional_fd(cmd_tx, rsp_rx, None).await } @@ -69,14 +76,12 @@ pub(crate) async fn listen_with_optional_fd( listener } else { // Won't help all that much, if we use the async version of fs. - if let Some(par) = path.parent(){ - std::fs::create_dir_all( - par - )?; + if let Some(par) = path.parent() { + std::fs::create_dir_all(par)?; } - match std::fs::remove_file(path){ - Err(e) if e.kind()==io::ErrorKind::NotFound => {Ok(())} - stuff => stuff + match std::fs::remove_file(path) { + Err(e) if e.kind() == io::ErrorKind::NotFound => Ok(()), + stuff => stuff, }?; info!("Relative path: {}", path.to_string_lossy()); UnixListener::bind(path)? @@ -98,18 +103,18 @@ pub(crate) async fn listen_with_optional_fd( while let Ok(Some(line)) = lines.next_line().await { info!("Got line: {}", line); debug!("Line raw data: {:?}", line.as_bytes()); - let mut res : DaemonResponse = DaemonResponseData::None.into(); + let mut res: DaemonResponse = DaemonResponseData::None.into(); let req = match serde_json::from_str::(&line) { Ok(req) => Some(req), Err(e) => { res.result = Err(e.to_string()); + tracing::error!("Failed to parse request: {}", e); None } }; let mut res = serde_json::to_string(&res).unwrap(); res.push('\n'); - if let Some(req) = req { cmd_tx.send(req.command).await.unwrap(); let res = rsp_rxc.recv().await.unwrap().with_id(req.id); @@ -117,6 +122,8 @@ pub(crate) async fn listen_with_optional_fd( retres.push('\n'); info!("Sending response: {}", retres); write_stream.write_all(retres.as_bytes()).await.unwrap(); + } else { + write_stream.write_all(res.as_bytes()).await.unwrap(); } } }); @@ -129,8 +136,7 @@ pub struct DaemonClient { impl DaemonClient { pub async fn new() -> Result { - let path = fetch_socket_path() - .ok_or(anyhow!("Failed to find socket path"))?; + let path = fetch_socket_path().ok_or(anyhow!("Failed to find socket path"))?; // debug!("found path: {:?}", path); let connection = UnixStream::connect(path).await?; debug!("connected to socket"); diff --git a/burrow/src/daemon/net/windows.rs b/burrow/src/daemon/net/windows.rs index 3f9d513..c734689 100644 --- a/burrow/src/daemon/net/windows.rs +++ b/burrow/src/daemon/net/windows.rs @@ -1,6 +1,9 @@ use super::*; -pub async fn listen(_cmd_tx: async_channel::Sender, _rsp_rx: async_channel::Receiver) -> Result<()> { +pub async fn listen( + _cmd_tx: async_channel::Sender, + _rsp_rx: async_channel::Receiver, +) -> Result<()> { unimplemented!("This platform does not currently support daemon mode.") } diff --git a/burrow/src/daemon/response.rs b/burrow/src/daemon/response.rs index da47150..63d10e8 100644 --- a/burrow/src/daemon/response.rs +++ b/burrow/src/daemon/response.rs @@ -7,30 +7,27 @@ use tun::TunInterface; pub struct DaemonResponse { // Error types can't be serialized, so this is the second best option. pub result: Result, - pub id: u32 + pub id: u32, } -impl DaemonResponse{ - pub fn new(result: Result) -> Self{ - Self{ +impl DaemonResponse { + pub fn new(result: Result) -> Self { + Self { result: result.map_err(|e| e.to_string()), - id: 0 + id: 0, } } } -impl Into for DaemonResponseData{ - fn into(self) -> DaemonResponse{ +impl Into for DaemonResponseData { + fn into(self) -> DaemonResponse { DaemonResponse::new(Ok::(self)) } } -impl DaemonResponse{ - pub fn with_id(self, id: u32) -> Self{ - Self { - id, - ..self - } +impl DaemonResponse { + pub fn with_id(self, id: u32) -> Self { + Self { id, ..self } } } @@ -38,24 +35,22 @@ impl DaemonResponse{ pub struct ServerInfo { pub name: Option, pub ip: Option, - pub mtu: Option + pub mtu: Option, } -impl TryFrom<&TunInterface> for ServerInfo{ +impl TryFrom<&TunInterface> for ServerInfo { type Error = anyhow::Error; - #[cfg(any(target_os="linux",target_vendor="apple"))] + #[cfg(any(target_os = "linux", target_vendor = "apple"))] fn try_from(server: &TunInterface) -> anyhow::Result { - Ok( - ServerInfo{ - name: server.name().ok(), - ip: server.ipv4_addr().ok().map(|ip| ip.to_string()), - mtu: server.mtu().ok() - } - ) + Ok(ServerInfo { + name: server.name().ok(), + ip: server.ipv4_addr().ok().map(|ip| ip.to_string()), + mtu: server.mtu().ok(), + }) } - #[cfg(not(any(target_os="linux",target_vendor="apple")))] + #[cfg(not(any(target_os = "linux", target_vendor = "apple")))] fn try_from(server: &TunInterface) -> anyhow::Result { Err(anyhow!("Not implemented in this platform")) } @@ -65,45 +60,55 @@ impl TryFrom<&TunInterface> for ServerInfo{ pub struct ServerConfig { pub address: Option, pub name: Option, - pub mtu: Option + pub mtu: Option, } impl Default for ServerConfig { fn default() -> Self { - Self{ - address: Some("10.0.0.1".to_string()), // Dummy remote address + Self { + address: Some("10.13.13.2".to_string()), // Dummy remote address name: None, - mtu: None + mtu: None, } } } #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] -pub enum DaemonResponseData{ +pub enum DaemonResponseData { ServerInfo(ServerInfo), ServerConfig(ServerConfig), - None + None, } #[test] -fn test_response_serialization() -> anyhow::Result<()>{ - insta::assert_snapshot!( - serde_json::to_string(&DaemonResponse::new(Ok::(DaemonResponseData::None)))? - ); - insta::assert_snapshot!( - serde_json::to_string(&DaemonResponse::new(Ok::(DaemonResponseData::ServerInfo(ServerInfo{ +fn test_response_serialization() -> anyhow::Result<()> { + insta::assert_snapshot!(serde_json::to_string(&DaemonResponse::new(Ok::< + DaemonResponseData, + String, + >( + DaemonResponseData::None + )))?); + insta::assert_snapshot!(serde_json::to_string(&DaemonResponse::new(Ok::< + DaemonResponseData, + String, + >( + DaemonResponseData::ServerInfo(ServerInfo { name: Some("burrow".to_string()), ip: None, mtu: Some(1500) - }))))? - ); - insta::assert_snapshot!( - serde_json::to_string(&DaemonResponse::new(Err::("error".to_string())))? - ); - insta::assert_snapshot!( - serde_json::to_string(&DaemonResponse::new(Ok::(DaemonResponseData::ServerConfig( - ServerConfig::default() - ))))? - ); + }) + )))?); + insta::assert_snapshot!(serde_json::to_string(&DaemonResponse::new(Err::< + DaemonResponseData, + String, + >( + "error".to_string() + )))?); + insta::assert_snapshot!(serde_json::to_string(&DaemonResponse::new(Ok::< + DaemonResponseData, + String, + >( + DaemonResponseData::ServerConfig(ServerConfig::default()) + )))?); Ok(()) -} \ No newline at end of file +} diff --git a/burrow/src/daemon/snapshots/burrow__daemon__command__daemoncommand_serialization-2.snap b/burrow/src/daemon/snapshots/burrow__daemon__command__daemoncommand_serialization-2.snap index 80b9e24..289851f 100644 --- a/burrow/src/daemon/snapshots/burrow__daemon__command__daemoncommand_serialization-2.snap +++ b/burrow/src/daemon/snapshots/burrow__daemon__command__daemoncommand_serialization-2.snap @@ -1,5 +1,5 @@ --- source: burrow/src/daemon/command.rs -expression: "serde_json::to_string(&DaemonCommand::ServerInfo).unwrap()" +expression: "serde_json::to_string(&DaemonCommand::Start(DaemonStartOptions {\n tun: TunOptions { seek_utun: true, ..TunOptions::default() },\n })).unwrap()" --- -"ServerInfo" +{"Start":{"tun":{"name":null,"no_pi":false,"tun_excl":false,"seek_utun":true,"address":null}}} diff --git a/burrow/src/daemon/snapshots/burrow__daemon__command__daemoncommand_serialization-3.snap b/burrow/src/daemon/snapshots/burrow__daemon__command__daemoncommand_serialization-3.snap index 8dc1b8b..80b9e24 100644 --- a/burrow/src/daemon/snapshots/burrow__daemon__command__daemoncommand_serialization-3.snap +++ b/burrow/src/daemon/snapshots/burrow__daemon__command__daemoncommand_serialization-3.snap @@ -1,5 +1,5 @@ --- source: burrow/src/daemon/command.rs -expression: "serde_json::to_string(&DaemonCommand::Stop).unwrap()" +expression: "serde_json::to_string(&DaemonCommand::ServerInfo).unwrap()" --- -"Stop" +"ServerInfo" diff --git a/burrow/src/daemon/snapshots/burrow__daemon__command__daemoncommand_serialization-4.snap b/burrow/src/daemon/snapshots/burrow__daemon__command__daemoncommand_serialization-4.snap index 9334ece..8dc1b8b 100644 --- a/burrow/src/daemon/snapshots/burrow__daemon__command__daemoncommand_serialization-4.snap +++ b/burrow/src/daemon/snapshots/burrow__daemon__command__daemoncommand_serialization-4.snap @@ -1,5 +1,5 @@ --- source: burrow/src/daemon/command.rs -expression: "serde_json::to_string(&DaemonCommand::ServerConfig).unwrap()" +expression: "serde_json::to_string(&DaemonCommand::Stop).unwrap()" --- -"ServerConfig" +"Stop" diff --git a/burrow/src/daemon/snapshots/burrow__daemon__command__daemoncommand_serialization-5.snap b/burrow/src/daemon/snapshots/burrow__daemon__command__daemoncommand_serialization-5.snap new file mode 100644 index 0000000..9334ece --- /dev/null +++ b/burrow/src/daemon/snapshots/burrow__daemon__command__daemoncommand_serialization-5.snap @@ -0,0 +1,5 @@ +--- +source: burrow/src/daemon/command.rs +expression: "serde_json::to_string(&DaemonCommand::ServerConfig).unwrap()" +--- +"ServerConfig" diff --git a/burrow/src/daemon/snapshots/burrow__daemon__command__daemoncommand_serialization.snap b/burrow/src/daemon/snapshots/burrow__daemon__command__daemoncommand_serialization.snap index 2f8af66..ff32838 100644 --- a/burrow/src/daemon/snapshots/burrow__daemon__command__daemoncommand_serialization.snap +++ b/burrow/src/daemon/snapshots/burrow__daemon__command__daemoncommand_serialization.snap @@ -2,4 +2,4 @@ source: burrow/src/daemon/command.rs expression: "serde_json::to_string(&DaemonCommand::Start(DaemonStartOptions::default())).unwrap()" --- -{"Start":{"tun":{"name":null,"no_pi":null,"tun_excl":null}}} +{"Start":{"tun":{"name":null,"no_pi":false,"tun_excl":false,"seek_utun":false,"address":null}}} diff --git a/burrow/src/daemon/snapshots/burrow__daemon__response__response_serialization-4.snap b/burrow/src/daemon/snapshots/burrow__daemon__response__response_serialization-4.snap index 95f9e7b..9752ebc 100644 --- a/burrow/src/daemon/snapshots/burrow__daemon__response__response_serialization-4.snap +++ b/burrow/src/daemon/snapshots/burrow__daemon__response__response_serialization-4.snap @@ -2,4 +2,4 @@ source: burrow/src/daemon/response.rs expression: "serde_json::to_string(&DaemonResponse::new(Ok::(DaemonResponseData::ServerConfig(ServerConfig::default()))))?" --- -{"result":{"Ok":{"ServerConfig":{"address":"10.0.0.1","name":null,"mtu":null}}},"id":0} +{"result":{"Ok":{"ServerConfig":{"address":"10.13.13.2","name":null,"mtu":null}}},"id":0} diff --git a/burrow/src/lib.rs b/burrow/src/lib.rs index 030022d..07cb2f6 100644 --- a/burrow/src/lib.rs +++ b/burrow/src/lib.rs @@ -14,31 +14,12 @@ use tun::TunInterface; // TODO Separate start and retrieve functions mod daemon; -pub use daemon::{DaemonCommand, DaemonResponseData, DaemonStartOptions, DaemonResponse, ServerInfo}; +pub use daemon::{ + DaemonCommand, DaemonResponse, DaemonResponseData, DaemonStartOptions, ServerInfo, +}; #[cfg(target_vendor = "apple")] mod apple; #[cfg(target_vendor = "apple")] pub use apple::*; - -#[cfg(any(target_os = "linux", target_vendor = "apple"))] -#[no_mangle] -pub extern "C" fn retrieve() -> i32 { - let iface2 = (1..100) - .filter_map(|i| { - let iface = unsafe { TunInterface::from_raw_fd(i) }; - match iface.name() { - Ok(_name) => Some(iface), - Err(_) => { - mem::forget(iface); - None - } - } - }) - .next(); - match iface2 { - Some(iface) => iface.as_raw_fd(), - None => -1, - } -} diff --git a/burrow/src/main.rs b/burrow/src/main.rs index 125d763..ff0ed53 100644 --- a/burrow/src/main.rs +++ b/burrow/src/main.rs @@ -4,22 +4,20 @@ use std::os::fd::FromRawFd; use anyhow::{Context, Result}; #[cfg(any(target_os = "linux", target_vendor = "apple"))] -use burrow::retrieve; use clap::{Args, Parser, Subcommand}; use tracing::{instrument, Level}; use tracing_log::LogTracer; use tracing_oslog::OsLogger; -use tracing_subscriber::{prelude::*, FmtSubscriber, EnvFilter}; +use tracing_subscriber::{prelude::*, EnvFilter, FmtSubscriber}; #[cfg(any(target_os = "linux", target_vendor = "apple"))] -use tun::TunInterface; - +use tun::{retrieve, TunInterface}; mod daemon; mod wireguard; +use crate::daemon::DaemonResponseData; use daemon::{DaemonClient, DaemonCommand, DaemonStartOptions}; use tun::TunOptions; -use crate::daemon::DaemonResponseData; #[derive(Parser)] #[command(name = "Burrow")] @@ -66,11 +64,9 @@ struct DaemonArgs {} async fn try_start() -> Result<()> { let mut client = DaemonClient::new().await?; client - .send_command(DaemonCommand::Start( - DaemonStartOptions{ - tun: TunOptions::new().address("10.13.13.2") - } - )) + .send_command(DaemonCommand::Start(DaemonStartOptions { + tun: TunOptions::new().address("10.13.13.2"), + })) .await .map(|_| ()) } @@ -93,8 +89,8 @@ async fn try_retrieve() -> Result<()> { } burrow::ensureroot::ensure_root(); - let iface2 = retrieve(); - tracing::info!("{}", iface2); + let iface2 = retrieve().ok_or(anyhow::anyhow!("No interface found"))?; + tracing::info!("{:?}", iface2); Ok(()) } @@ -109,9 +105,10 @@ async fn initialize_tracing() -> Result<()> { FmtSubscriber::builder() .with_line_number(true) .with_env_filter(EnvFilter::from_default_env()) - .finish() + .finish(), ); - tracing::subscriber::set_global_default(logger).context("Failed to set the global tracing subscriber")?; + tracing::subscriber::set_global_default(logger) + .context("Failed to set the global tracing subscriber")?; } } @@ -126,7 +123,7 @@ async fn try_stop() -> Result<()> { } #[cfg(any(target_os = "linux", target_vendor = "apple"))] -async fn try_serverinfo() -> Result<()>{ +async fn try_serverinfo() -> Result<()> { let mut client = DaemonClient::new().await?; let res = client.send_command(DaemonCommand::ServerInfo).await?; match res.result { @@ -136,7 +133,9 @@ async fn try_serverinfo() -> Result<()>{ Ok(DaemonResponseData::None) => { println!("Server not started.") } - Ok(res) => {println!("Unexpected Response: {:?}", res)} + Ok(res) => { + println!("Unexpected Response: {:?}", res) + } Err(e) => { println!("Error when retrieving from server: {}", e) } @@ -145,7 +144,7 @@ async fn try_serverinfo() -> Result<()>{ } #[cfg(any(target_os = "linux", target_vendor = "apple"))] -async fn try_serverconfig() -> Result<()>{ +async fn try_serverconfig() -> Result<()> { let mut client = DaemonClient::new().await?; let res = client.send_command(DaemonCommand::ServerConfig).await?; match res.result { @@ -155,7 +154,9 @@ async fn try_serverconfig() -> Result<()>{ Ok(DaemonResponseData::None) => { println!("Server not started.") } - Ok(res) => {println!("Unexpected Response: {:?}", res)} + Ok(res) => { + println!("Unexpected Response: {:?}", res) + } Err(e) => { println!("Error when retrieving from server: {}", e) } @@ -206,12 +207,8 @@ async fn main() -> Result<()> { try_stop().await?; } Commands::Daemon(_) => daemon::daemon_main().await?, - Commands::ServerInfo => { - try_serverinfo().await? - } - Commands::ServerConfig => { - try_serverconfig().await? - } + Commands::ServerInfo => try_serverinfo().await?, + Commands::ServerConfig => try_serverconfig().await?, } Ok(()) diff --git a/burrow/src/wireguard/iface.rs b/burrow/src/wireguard/iface.rs index 4a00cbe..52f719b 100755 --- a/burrow/src/wireguard/iface.rs +++ b/burrow/src/wireguard/iface.rs @@ -1,21 +1,22 @@ -use std::{net::IpAddr, rc::Rc}; use std::sync::Arc; use std::time::Duration; +use std::{net::IpAddr, rc::Rc}; use anyhow::Error; use async_trait::async_trait; use fehler::throws; +use futures::future::join_all; +use futures::FutureExt; use ip_network_table::IpNetworkTable; use log::log; +use tokio::time::timeout; use tokio::{ join, sync::{Mutex, RwLock}, task::{self, JoinHandle}, }; +use tracing::{debug, error}; use tun::tokio::TunInterface; -use futures::future::join_all; -use futures::FutureExt; -use tokio::time::timeout; use super::{noise::Tunnel, pcb, Peer, PeerPcb}; @@ -90,10 +91,7 @@ impl Interface { .collect::>()?; let pcbs = Arc::new(pcbs); - Self { - pcbs, - tun: None - } + Self { pcbs, tun: None } } pub fn set_tun(&mut self, tun: Arc>) { @@ -101,66 +99,82 @@ impl Interface { } pub async fn run(&self) -> anyhow::Result<()> { + debug!("RUN: starting interface"); let pcbs = self.pcbs.clone(); - let tun = self.tun.clone().ok_or(anyhow::anyhow!("tun interface does not exist"))?; + let tun = self + .tun + .clone() + .ok_or(anyhow::anyhow!("tun interface does not exist"))?; log::info!("starting interface"); let outgoing = async move { loop { - // log::debug!("starting loop..."); + // tracing::debug!("starting loop..."); let mut buf = [0u8; 3000]; let src = { - // log::debug!("awaiting read..."); - let src = match timeout(Duration::from_millis(10), tun.write().await.recv(&mut buf[..])).await { + let src = match timeout( + Duration::from_millis(10), + tun.write().await.recv(&mut buf[..]), + ) + .await + { Ok(Ok(len)) => &buf[..len], - Ok(Err(e)) => {continue} + Ok(Err(e)) => { + error!("failed to read from interface: {}", e); + continue; + } Err(_would_block) => { - continue + debug!("read timed out"); + continue; } }; - log::debug!("read {} bytes from interface", src.len()); - log::debug!("bytes: {:?}", src); + debug!("read {} bytes from interface", src.len()); + debug!("bytes: {:?}", src); src }; - let dst_addr = match Tunnel::dst_address(src) { Some(addr) => addr, None => { - log::debug!("no destination found"); - continue - }, + tracing::debug!("no destination found"); + continue; + } }; - log::debug!("dst_addr: {}", dst_addr); + tracing::debug!("dst_addr: {}", dst_addr); let Some(idx) = pcbs.find(dst_addr) else { continue }; - log::debug!("found peer:{}", idx); + tracing::debug!("found peer:{}", idx); match pcbs.pcbs[idx].read().await.send(src).await { Ok(..) => { - log::debug!("sent packet to peer {}", dst_addr); + tracing::debug!("sent packet to peer {}", dst_addr); } Err(e) => { log::error!("failed to send packet {}", e); - continue - }, + continue; + } }; } }; let mut tsks = vec![]; - let tun = self.tun.clone().ok_or(anyhow::anyhow!("tun interface does not exist"))?; - + let tun = self + .tun + .clone() + .ok_or(anyhow::anyhow!("tun interface does not exist"))?; let outgoing = tokio::task::spawn(outgoing); tsks.push(outgoing); + debug!("preparing to spawn read tasks"); + { let pcbs = &self.pcbs; - for i in 0..pcbs.pcbs.len(){ + for i in 0..pcbs.pcbs.len() { + debug!("spawning read task for peer {}", i); let mut pcb = pcbs.pcbs[i].clone(); let tun = tun.clone(); let tsk = async move { @@ -168,23 +182,25 @@ impl Interface { let r1 = pcb.write().await.open_if_closed().await; if let Err(e) = r1 { log::error!("failed to open pcb: {}", e); - return + return; } } let r2 = pcb.read().await.run(tun).await; if let Err(e) = r2 { log::error!("failed to run pcb: {}", e); - return + return; } else { - log::debug!("pcb ran successfully"); + tracing::debug!("pcb ran successfully"); } }; + debug!("task made.."); tsks.push(tokio::spawn(tsk)); } - log::debug!("spawned read tasks"); + debug!("spawned read tasks"); } - log::debug!("preparing to join.."); + debug!("preparing to join.."); join_all(tsks).await; + debug!("joined!"); Ok(()) } } diff --git a/burrow/src/wireguard/noise/handshake.rs b/burrow/src/wireguard/noise/handshake.rs index 92c456c..3f8c91b 100755 --- a/burrow/src/wireguard/noise/handshake.rs +++ b/burrow/src/wireguard/noise/handshake.rs @@ -9,20 +9,14 @@ use std::{ use aead::{Aead, Payload}; use blake2::{ digest::{FixedOutput, KeyInit}, - Blake2s256, - Blake2sMac, - Digest, + Blake2s256, Blake2sMac, Digest, }; use chacha20poly1305::XChaCha20Poly1305; use rand_core::OsRng; use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305}; use super::{ - errors::WireGuardError, - session::Session, - x25519, - HandshakeInit, - HandshakeResponse, + errors::WireGuardError, session::Session, x25519, HandshakeInit, HandshakeResponse, PacketCookieReply, }; @@ -137,8 +131,8 @@ fn aead_chacha20_open( let mut nonce: [u8; 12] = [0; 12]; nonce[4..].copy_from_slice(&counter.to_le_bytes()); - log::debug!("TAG A"); - log::debug!("{:?};{:?};{:?};{:?};{}", key, data, aad, nonce, counter); + tracing::debug!("TAG A"); + tracing::debug!("{:?};{:?};{:?};{:?};{}", key, data, aad, nonce, counter); aead_chacha20_open_inner(buffer, key, nonce, data, aad) .map_err(|_| WireGuardError::InvalidAeadTag)?; @@ -213,7 +207,7 @@ impl Tai64N { /// Parse a timestamp from a 12 byte u8 slice fn parse(buf: &[u8; 12]) -> Result { if buf.len() < 12 { - return Err(WireGuardError::InvalidTai64nTimestamp) + return Err(WireGuardError::InvalidTai64nTimestamp); } let (sec_bytes, nano_bytes) = buf.split_at(std::mem::size_of::()); @@ -560,19 +554,22 @@ impl Handshake { let timestamp = Tai64N::parse(×tamp)?; if !timestamp.after(&self.last_handshake_timestamp) { // Possibly a replay - return Err(WireGuardError::WrongTai64nTimestamp) + return Err(WireGuardError::WrongTai64nTimestamp); } self.last_handshake_timestamp = timestamp; // initiator.hash = HASH(initiator.hash || msg.encrypted_timestamp) hash = b2s_hash(&hash, packet.encrypted_timestamp); - self.previous = std::mem::replace(&mut self.state, HandshakeState::InitReceived { - chaining_key, - hash, - peer_ephemeral_public, - peer_index, - }); + self.previous = std::mem::replace( + &mut self.state, + HandshakeState::InitReceived { + chaining_key, + hash, + peer_ephemeral_public, + peer_index, + }, + ); self.format_handshake_response(dst) } @@ -673,7 +670,7 @@ impl Handshake { let local_index = self.cookies.index; if packet.receiver_idx != local_index { - return Err(WireGuardError::WrongIndex) + return Err(WireGuardError::WrongIndex); } // msg.encrypted_cookie = XAEAD(HASH(LABEL_COOKIE || responder.static_public), // msg.nonce, cookie, last_received_msg.mac1) @@ -683,7 +680,7 @@ impl Handshake { aad: &mac1[0..16], msg: packet.encrypted_cookie, }; - log::debug!("TAG B"); + tracing::debug!("TAG B"); let plaintext = XChaCha20Poly1305::new_from_slice(&key) .unwrap() .decrypt(packet.nonce.into(), payload) @@ -730,7 +727,7 @@ impl Handshake { dst: &'a mut [u8], ) -> Result<&'a mut [u8], WireGuardError> { if dst.len() < super::HANDSHAKE_INIT_SZ { - return Err(WireGuardError::DestinationBufferTooSmall) + return Err(WireGuardError::DestinationBufferTooSmall); } let (message_type, rest) = dst.split_at_mut(4); @@ -813,7 +810,7 @@ impl Handshake { dst: &'a mut [u8], ) -> Result<(&'a mut [u8], Session), WireGuardError> { if dst.len() < super::HANDSHAKE_RESP_SZ { - return Err(WireGuardError::DestinationBufferTooSmall) + return Err(WireGuardError::DestinationBufferTooSmall); } let state = std::mem::replace(&mut self.state, HandshakeState::None); diff --git a/burrow/src/wireguard/noise/mod.rs b/burrow/src/wireguard/noise/mod.rs index 7e2184d..824d7c1 100755 --- a/burrow/src/wireguard/noise/mod.rs +++ b/burrow/src/wireguard/noise/mod.rs @@ -45,11 +45,7 @@ const N_SESSIONS: usize = 8; pub mod x25519 { pub use x25519_dalek::{ - EphemeralSecret, - PublicKey, - ReusableSecret, - SharedSecret, - StaticSecret, + EphemeralSecret, PublicKey, ReusableSecret, SharedSecret, StaticSecret, }; } @@ -141,12 +137,12 @@ impl Tunnel { #[inline(always)] pub fn parse_incoming_packet(src: &[u8]) -> Result { if src.len() < 4 { - return Err(WireGuardError::InvalidPacket) + return Err(WireGuardError::InvalidPacket); } // Checks the type, as well as the reserved zero fields let packet_type = u32::from_le_bytes(src[0..4].try_into().unwrap()); - log::debug!("packet_type: {}", packet_type); + tracing::debug!("packet_type: {}", packet_type); Ok(match (packet_type, src.len()) { (HANDSHAKE_INIT, HANDSHAKE_INIT_SZ) => Packet::HandshakeInit(HandshakeInit { @@ -183,7 +179,7 @@ impl Tunnel { pub fn dst_address(packet: &[u8]) -> Option { if packet.is_empty() { - return None + return None; } match packet[0] >> 4 { @@ -278,7 +274,7 @@ impl Tunnel { self.timer_tick(TimerName::TimeLastDataPacketSent); } self.tx_bytes += src.len(); - return TunnResult::WriteToNetwork(packet) + return TunnResult::WriteToNetwork(packet); } // If there is no session, queue the packet for future retry @@ -302,7 +298,7 @@ impl Tunnel { ) -> TunnResult<'a> { if datagram.is_empty() { // Indicates a repeated call - return self.send_queued_packet(dst) + return self.send_queued_packet(dst); } let mut cookie = [0u8; COOKIE_REPLY_SZ]; @@ -313,7 +309,7 @@ impl Tunnel { Ok(packet) => packet, Err(TunnResult::WriteToNetwork(cookie)) => { dst[..cookie.len()].copy_from_slice(cookie); - return TunnResult::WriteToNetwork(&mut dst[..cookie.len()]) + return TunnResult::WriteToNetwork(&mut dst[..cookie.len()]); } Err(TunnResult::Err(e)) => return TunnResult::Err(e), _ => unreachable!(), @@ -413,7 +409,7 @@ impl Tunnel { let cur_idx = self.current; if cur_idx == new_idx { // There is nothing to do, already using this session, this is the common case - return + return; } if self.sessions[cur_idx % N_SESSIONS].is_none() || self.timers.session_timers[new_idx % N_SESSIONS] @@ -459,7 +455,7 @@ impl Tunnel { force_resend: bool, ) -> TunnResult<'a> { if self.handshake.is_in_progress() && !force_resend { - return TunnResult::Done + return TunnResult::Done; } if self.handshake.is_expired() { @@ -518,7 +514,7 @@ impl Tunnel { }; if computed_len > packet.len() { - return TunnResult::Err(WireGuardError::InvalidPacket) + return TunnResult::Err(WireGuardError::InvalidPacket); } self.timer_tick(TimerName::TimeLastDataPacketReceived); diff --git a/burrow/src/wireguard/noise/rate_limiter.rs b/burrow/src/wireguard/noise/rate_limiter.rs index df29f93..02887ee 100755 --- a/burrow/src/wireguard/noise/rate_limiter.rs +++ b/burrow/src/wireguard/noise/rate_limiter.rs @@ -13,19 +13,9 @@ use ring::constant_time::verify_slices_are_equal; use super::{ handshake::{ - b2s_hash, - b2s_keyed_mac_16, - b2s_keyed_mac_16_2, - b2s_mac_24, - LABEL_COOKIE, - LABEL_MAC1, + b2s_hash, b2s_keyed_mac_16, b2s_keyed_mac_16_2, b2s_mac_24, LABEL_COOKIE, LABEL_MAC1, }, - HandshakeInit, - HandshakeResponse, - Packet, - TunnResult, - Tunnel, - WireGuardError, + HandshakeInit, HandshakeResponse, Packet, TunnResult, Tunnel, WireGuardError, }; const COOKIE_REFRESH: u64 = 128; // Use 128 and not 120 so the compiler can optimize out the division @@ -137,7 +127,7 @@ impl RateLimiter { dst: &'a mut [u8], ) -> Result<&'a mut [u8], WireGuardError> { if dst.len() < super::COOKIE_REPLY_SZ { - return Err(WireGuardError::DestinationBufferTooSmall) + return Err(WireGuardError::DestinationBufferTooSmall); } let (message_type, rest) = dst.split_at_mut(4); @@ -174,14 +164,14 @@ impl RateLimiter { dst: &'b mut [u8], ) -> Result, TunnResult<'b>> { let packet = Tunnel::parse_incoming_packet(src)?; - log::debug!("packet: {:?}", packet); + tracing::debug!("packet: {:?}", packet); // Verify and rate limit handshake messages only if let Packet::HandshakeInit(HandshakeInit { sender_idx, .. }) | Packet::HandshakeResponse(HandshakeResponse { sender_idx, .. }) = packet { - log::debug!("sender_idx: {}", sender_idx); - log::debug!("response: {:?}", packet); + tracing::debug!("sender_idx: {}", sender_idx); + tracing::debug!("response: {:?}", packet); let (msg, macs) = src.split_at(src.len() - 32); let (mac1, mac2) = macs.split_at(16); @@ -203,7 +193,7 @@ impl RateLimiter { let cookie_packet = self .format_cookie_reply(sender_idx, cookie, mac1, dst) .map_err(TunnResult::Err)?; - return Err(TunnResult::WriteToNetwork(cookie_packet)) + return Err(TunnResult::WriteToNetwork(cookie_packet)); } } } diff --git a/burrow/src/wireguard/noise/session.rs b/burrow/src/wireguard/noise/session.rs index eb7dbef..14c191b 100755 --- a/burrow/src/wireguard/noise/session.rs +++ b/burrow/src/wireguard/noise/session.rs @@ -88,11 +88,11 @@ impl ReceivingKeyCounterValidator { fn will_accept(&self, counter: u64) -> Result<(), WireGuardError> { if counter >= self.next { // As long as the counter is growing no replay took place for sure - return Ok(()) + return Ok(()); } if counter + N_BITS < self.next { // Drop if too far back - return Err(WireGuardError::InvalidCounter) + return Err(WireGuardError::InvalidCounter); } if !self.check_bit(counter) { Ok(()) @@ -107,22 +107,22 @@ impl ReceivingKeyCounterValidator { fn mark_did_receive(&mut self, counter: u64) -> Result<(), WireGuardError> { if counter + N_BITS < self.next { // Drop if too far back - return Err(WireGuardError::InvalidCounter) + return Err(WireGuardError::InvalidCounter); } if counter == self.next { // Usually the packets arrive in order, in that case we simply mark the bit and // increment the counter self.set_bit(counter); self.next += 1; - return Ok(()) + return Ok(()); } if counter < self.next { // A packet arrived out of order, check if it is valid, and mark if self.check_bit(counter) { - return Err(WireGuardError::InvalidCounter) + return Err(WireGuardError::InvalidCounter); } self.set_bit(counter); - return Ok(()) + return Ok(()); } // Packets where dropped, or maybe reordered, skip them and mark unused if counter - self.next >= N_BITS { @@ -247,13 +247,13 @@ impl Session { panic!("The destination buffer is too small"); } if packet.receiver_idx != self.receiving_index { - return Err(WireGuardError::WrongIndex) + return Err(WireGuardError::WrongIndex); } // Don't reuse counters, in case this is a replay attack we want to quickly // check the counter without running expensive decryption self.receiving_counter_quick_check(packet.counter)?; - log::debug!("TAG C"); + tracing::debug!("TAG C"); let ret = { let mut nonce = [0u8; 12]; nonce[4..12].copy_from_slice(&packet.counter.to_le_bytes()); diff --git a/burrow/src/wireguard/noise/timers.rs b/burrow/src/wireguard/noise/timers.rs index 1d0cf1f..f713e6f 100755 --- a/burrow/src/wireguard/noise/timers.rs +++ b/burrow/src/wireguard/noise/timers.rs @@ -190,7 +190,7 @@ impl Tunnel { { if self.handshake.is_expired() { - return TunnResult::Err(WireGuardError::ConnectionExpired) + return TunnResult::Err(WireGuardError::ConnectionExpired); } // Clear cookie after COOKIE_EXPIRATION_TIME @@ -206,7 +206,7 @@ impl Tunnel { tracing::error!("CONNECTION_EXPIRED(REJECT_AFTER_TIME * 3)"); self.handshake.set_expired(); self.clear_all(); - return TunnResult::Err(WireGuardError::ConnectionExpired) + return TunnResult::Err(WireGuardError::ConnectionExpired); } if let Some(time_init_sent) = self.handshake.timer() { @@ -219,7 +219,7 @@ impl Tunnel { tracing::error!("CONNECTION_EXPIRED(REKEY_ATTEMPT_TIME)"); self.handshake.set_expired(); self.clear_all(); - return TunnResult::Err(WireGuardError::ConnectionExpired) + return TunnResult::Err(WireGuardError::ConnectionExpired); } if time_init_sent.elapsed() >= REKEY_TIMEOUT { @@ -299,11 +299,11 @@ impl Tunnel { } if handshake_initiation_required { - return self.format_handshake_initiation(dst, true) + return self.format_handshake_initiation(dst, true); } if keepalive_required { - return self.encapsulate(&[], dst) + return self.encapsulate(&[], dst); } TunnResult::Done diff --git a/burrow/src/wireguard/pcb.rs b/burrow/src/wireguard/pcb.rs index 6fcaa15..6acc8d8 100755 --- a/burrow/src/wireguard/pcb.rs +++ b/burrow/src/wireguard/pcb.rs @@ -9,11 +9,11 @@ use fehler::throws; use ip_network::IpNetwork; use log::log; use rand::random; -use tokio::{net::UdpSocket, task::JoinHandle}; use tokio::sync::{Mutex, RwLock}; use tokio::time::timeout; -use uuid::uuid; +use tokio::{net::UdpSocket, task::JoinHandle}; use tun::tokio::TunInterface; +use uuid::uuid; use super::{ iface::PacketInterface, @@ -33,15 +33,24 @@ pub struct PeerPcb { impl PeerPcb { #[throws] pub fn new(peer: Peer) -> Self { - let tunnel = RwLock::new(Tunnel::new(peer.private_key, peer.public_key, peer.preshared_key, None, 1, None) - .map_err(|s| anyhow::anyhow!("{}", s))?); + let tunnel = RwLock::new( + Tunnel::new( + peer.private_key, + peer.public_key, + peer.preshared_key, + None, + 1, + None, + ) + .map_err(|s| anyhow::anyhow!("{}", s))?, + ); Self { endpoint: peer.endpoint, allowed_ips: peer.allowed_ips, handle: None, socket: None, - tunnel + tunnel, } } @@ -56,7 +65,7 @@ impl PeerPcb { pub async fn run(&self, tun_interface: Arc>) -> Result<(), Error> { let mut buf = [0u8; 3000]; - log::debug!("starting read loop for pcb..."); + tracing::debug!("starting read loop for pcb..."); loop { tracing::debug!("waiting for packet"); let len = self.recv(&mut buf, tun_interface.clone()).await?; @@ -64,29 +73,38 @@ impl PeerPcb { } } - pub async fn recv(&self, buf: &mut [u8], tun_interface: Arc>) -> Result { - log::debug!("starting read loop for pcb... for {:?}", &self); + pub async fn recv( + &self, + buf: &mut [u8], + tun_interface: Arc>, + ) -> Result { + tracing::debug!("starting read loop for pcb... for {:?}", &self); let rid: i32 = random(); - log::debug!("start read loop {}", rid); - loop{ - log::debug!("{}: waiting for packet", rid); + tracing::debug!("start read loop {}", rid); + loop { + tracing::debug!("{}: waiting for packet", rid); let Some(socket) = &self.socket else { continue }; - let mut res_buf = [0;1500]; - // log::debug!("{} : waiting for readability on {:?}", rid, socket); + let mut res_buf = [0; 1500]; + // tracing::debug!("{} : waiting for readability on {:?}", rid, socket); let len = match socket.recv(&mut res_buf).await { - Ok(l) => {l} + Ok(l) => l, Err(e) => { log::error!("{}: error reading from socket: {:?}", rid, e); - continue + continue; } }; let mut res_dat = &res_buf[..len]; tracing::debug!("{}: Decapsulating {} bytes", rid, len); tracing::debug!("{:?}", &res_dat); loop { - match self.tunnel.write().await.decapsulate(None, res_dat, &mut buf[..]) { + match self + .tunnel + .write() + .await + .decapsulate(None, res_dat, &mut buf[..]) + { TunnResult::Done => { break; } @@ -113,7 +131,7 @@ impl PeerPcb { } } } - return Ok(len) + return Ok(len); } } @@ -122,7 +140,6 @@ impl PeerPcb { Ok(self.socket.as_ref().expect("socket was just opened")) } - pub async fn send(&self, src: &[u8]) -> Result<(), Error> { let mut dst_buf = [0u8; 3000]; match self.tunnel.write().await.encapsulate(src, &mut dst_buf[..]) { diff --git a/burrow/src/wireguard/peer.rs b/burrow/src/wireguard/peer.rs index cc8a296..27c5399 100755 --- a/burrow/src/wireguard/peer.rs +++ b/burrow/src/wireguard/peer.rs @@ -10,7 +10,7 @@ pub struct Peer { pub private_key: StaticSecret, pub public_key: PublicKey, pub allowed_ips: Vec, - pub preshared_key: Option<[u8; 32]> + pub preshared_key: Option<[u8; 32]>, } impl fmt::Debug for Peer { diff --git a/tun/build.rs b/tun/build.rs index 8da8a40..03ee131 100644 --- a/tun/build.rs +++ b/tun/build.rs @@ -26,7 +26,7 @@ async fn generate(out_dir: &std::path::Path) -> anyhow::Result<()> { println!("cargo:rerun-if-changed={}", binary_path.to_str().unwrap()); if let (Ok(..), Ok(..)) = (File::open(&bindings_path), File::open(&binary_path)) { - return Ok(()) + return Ok(()); }; let archive = download(out_dir) diff --git a/tun/src/lib.rs b/tun/src/lib.rs index a1ca636..64e17df 100644 --- a/tun/src/lib.rs +++ b/tun/src/lib.rs @@ -15,4 +15,4 @@ mod options; pub mod tokio; pub use options::TunOptions; -pub use os_imp::{TunInterface, TunQueue}; +pub use os_imp::{retrieve, TunInterface, TunQueue}; diff --git a/tun/src/options.rs b/tun/src/options.rs index 82cadfd..aafdad2 100644 --- a/tun/src/options.rs +++ b/tun/src/options.rs @@ -5,31 +5,42 @@ use fehler::throws; use super::TunInterface; #[derive(Debug, Clone, Default)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema))] +#[cfg_attr( + feature = "serde", + derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema) +)] pub struct TunOptions { /// (Windows + Linux) Name the tun interface. - pub(crate) name: Option, + pub name: Option, /// (Linux) Don't include packet information. - pub(crate) no_pi: Option<()>, + pub no_pi: bool, /// (Linux) Avoid opening an existing persistant device. - pub(crate) tun_excl: Option<()>, + pub tun_excl: bool, /// (MacOS) Whether to seek the first available utun device. - pub(crate) seek_utun: Option<()>, + pub seek_utun: Option, /// (Linux) The IP address of the tun interface. - pub(crate) address: Option, + pub address: Option, } impl TunOptions { - pub fn new() -> Self { Self::default() } + pub fn new() -> Self { + Self::default() + } pub fn name(mut self, name: &str) -> Self { self.name = Some(name.to_owned()); self } - pub fn no_pi(mut self, enable: bool) { self.no_pi = enable.then_some(()); } + pub fn no_pi(mut self, enable: bool) -> Self { + self.no_pi = enable; + self + } - pub fn tun_excl(mut self, enable: bool) { self.tun_excl = enable.then_some(()); } + pub fn tun_excl(mut self, enable: bool) -> Self { + self.tun_excl = enable; + self + } pub fn address(mut self, address: impl ToString) -> Self { self.address = Some(address.to_string()); @@ -37,5 +48,7 @@ impl TunOptions { } #[throws] - pub fn open(self) -> TunInterface { TunInterface::new_with_options(self)? } + pub fn open(self) -> TunInterface { + TunInterface::new_with_options(self)? + } } diff --git a/tun/src/tokio/mod.rs b/tun/src/tokio/mod.rs index fb924ff..c901cba 100644 --- a/tun/src/tokio/mod.rs +++ b/tun/src/tokio/mod.rs @@ -12,9 +12,7 @@ impl TunInterface { #[instrument] pub fn new(mut tun: crate::TunInterface) -> io::Result { tun.set_nonblocking(true)?; - Ok(Self { - inner: AsyncFd::new(tun)?, - }) + Ok(Self { inner: AsyncFd::new(tun)? }) } #[instrument] @@ -31,22 +29,22 @@ impl TunInterface { // #[instrument] pub async fn recv(&mut self, buf: &mut [u8]) -> io::Result { loop { - // log::debug!("TunInterface receiving..."); + // tracing::debug!("TunInterface receiving..."); let mut guard = self.inner.readable_mut().await?; - // log::debug!("Got! readable_mut"); + // tracing::debug!("Got! readable_mut"); match guard.try_io(|inner| { let raw_ref = (*inner).get_mut(); let recved = raw_ref.recv(buf); recved }) { Ok(result) => { - log::debug!("HORRAY"); - return result - }, + tracing::debug!("HORRAY"); + return result; + } Err(_would_block) => { - log::debug!("WouldBlock"); - continue - }, + tracing::debug!("WouldBlock"); + continue; + } } } } diff --git a/tun/src/unix/apple/mod.rs b/tun/src/unix/apple/mod.rs index b419294..ca3ddc7 100644 --- a/tun/src/unix/apple/mod.rs +++ b/tun/src/unix/apple/mod.rs @@ -9,11 +9,12 @@ use byteorder::{ByteOrder, NetworkEndian}; use fehler::throws; use libc::{c_char, iovec, writev, AF_INET, AF_INET6}; use socket2::{Domain, SockAddr, Socket, Type}; -use tracing::{self, instrument}; +use tracing::{self, debug, instrument}; mod kern_control; -mod sys; +pub mod sys; +use crate::retrieve; use kern_control::SysControlSocket; pub use super::queue::TunQueue; @@ -34,8 +35,13 @@ impl TunInterface { #[throws] #[instrument] pub fn new_with_options(options: TunOptions) -> TunInterface { - let ti = TunInterface::connect(0)?; - if let Some(addr) = options.address{ + debug!("Opening tun interface with options: {:?}", &options); + let ti = if let Some(n) = options.seek_utun { + retrieve().ok_or(Error::new(std::io::ErrorKind::NotFound, "No utun found"))? + } else { + TunInterface::connect(0)? + }; + if let Some(addr) = options.address { if let Ok(addr) = addr.parse() { ti.set_ipv4_addr(addr)?; } diff --git a/tun/src/unix/apple/sys.rs b/tun/src/unix/apple/sys.rs index b4d4a6a..c0ea613 100644 --- a/tun/src/unix/apple/sys.rs +++ b/tun/src/unix/apple/sys.rs @@ -2,20 +2,11 @@ use std::mem; use libc::{c_char, c_int, c_short, c_uint, c_ulong, sockaddr}; pub use libc::{ - c_void, - sockaddr_ctl, - sockaddr_in, - socklen_t, - AF_SYSTEM, - AF_SYS_CONTROL, - IFNAMSIZ, + c_void, sockaddr_ctl, sockaddr_in, socklen_t, AF_SYSTEM, AF_SYS_CONTROL, IFNAMSIZ, SYSPROTO_CONTROL, }; use nix::{ - ioctl_read_bad, - ioctl_readwrite, - ioctl_write_ptr_bad, - request_code_readwrite, + ioctl_read_bad, ioctl_readwrite, ioctl_write_ptr_bad, request_code_readwrite, request_code_write, }; diff --git a/tun/src/unix/linux/mod.rs b/tun/src/unix/linux/mod.rs index 90cf353..b4b5b8c 100644 --- a/tun/src/unix/linux/mod.rs +++ b/tun/src/unix/linux/mod.rs @@ -26,7 +26,9 @@ pub struct TunInterface { impl TunInterface { #[throws] #[instrument] - pub fn new() -> TunInterface { Self::new_with_options(TunOptions::new())? } + pub fn new() -> TunInterface { + Self::new_with_options(TunOptions::new())? + } #[throws] #[instrument] @@ -212,5 +214,7 @@ impl TunInterface { #[throws] #[instrument] - pub fn send(&self, buf: &[u8]) -> usize { self.socket.send(buf)? } + pub fn send(&self, buf: &[u8]) -> usize { + self.socket.send(buf)? + } } diff --git a/tun/src/unix/mod.rs b/tun/src/unix/mod.rs index bd9ffb4..72a8795 100644 --- a/tun/src/unix/mod.rs +++ b/tun/src/unix/mod.rs @@ -1,11 +1,13 @@ +use std::mem::size_of; use std::{ io::{Error, Read}, + mem, os::fd::{AsRawFd, FromRawFd, IntoRawFd, RawFd}, }; -use tracing::instrument; +use tracing::{debug, error, instrument}; -use super::TunOptions; +use super::{syscall, TunOptions}; mod queue; @@ -17,9 +19,13 @@ mod imp; #[path = "linux/mod.rs"] mod imp; +use crate::os_imp::imp::sys; +use crate::os_imp::imp::sys::resolve_ctl_info; use fehler::throws; pub use imp::TunInterface; +use libc::{getpeername, sockaddr_ctl, sockaddr_storage, socklen_t, AF_SYSTEM, AF_SYS_CONTROL}; pub use queue::TunQueue; +use socket2::SockAddr; impl AsRawFd for TunInterface { fn as_raw_fd(&self) -> RawFd { @@ -47,8 +53,8 @@ impl TunInterface { // there might be a more efficient way to implement this let tmp_buf = &mut [0u8; 1500]; let len = self.socket.read(tmp_buf)?; - buf[..len-4].copy_from_slice(&tmp_buf[4..len]); - len-4 + buf[..len - 4].copy_from_slice(&tmp_buf[4..len]); + len - 4 } #[throws] @@ -76,3 +82,35 @@ pub fn string_to_ifname(name: &str) -> [libc::c_char; libc::IFNAMSIZ] { buf[..len].copy_from_slice(unsafe { &*(name.as_bytes() as *const _ as *const [libc::c_char]) }); buf } + +#[cfg(any(target_os = "linux", target_vendor = "apple"))] +pub fn retrieve() -> Option { + (3..100) + .filter_map(|i| { + let result = unsafe { + let mut addr = sockaddr_ctl { + sc_len: size_of::() as u8, + sc_family: 0, + ss_sysaddr: 0, + sc_id: 0, + sc_unit: 0, + sc_reserved: Default::default(), + }; + let mut len = mem::size_of::() as libc::socklen_t; + let res = syscall!(getpeername(i, &mut addr as *mut _ as *mut _, len as *mut _)); + tracing::debug!("getpeername{}: {:?}", i, res); + if res.is_err() { + return None; + } + if addr.sc_family == sys::AF_SYSTEM as u8 + && addr.ss_sysaddr == sys::AF_SYS_CONTROL as u16 + { + Some(TunInterface::from_raw_fd(i)) + } else { + None + } + }; + result + }) + .next() +} diff --git a/tun/src/windows/mod.rs b/tun/src/windows/mod.rs index 9b6d5ad..dadd53f 100644 --- a/tun/src/windows/mod.rs +++ b/tun/src/windows/mod.rs @@ -25,7 +25,9 @@ impl Debug for TunInterface { impl TunInterface { #[throws] - pub fn new() -> TunInterface { Self::new_with_options(TunOptions::new())? } + pub fn new() -> TunInterface { + Self::new_with_options(TunOptions::new())? + } #[throws] pub(crate) fn new_with_options(options: TunOptions) -> TunInterface { @@ -37,17 +39,18 @@ impl TunInterface { if handle.is_null() { unsafe { GetLastError() }.ok()? } - TunInterface { - handle, - name: name_owned, - } + TunInterface { handle, name: name_owned } } - pub fn name(&self) -> String { self.name.clone() } + pub fn name(&self) -> String { + self.name.clone() + } } impl Drop for TunInterface { - fn drop(&mut self) { unsafe { sys::WINTUN.WintunCloseAdapter(self.handle) } } + fn drop(&mut self) { + unsafe { sys::WINTUN.WintunCloseAdapter(self.handle) } + } } pub(crate) mod sys { diff --git a/tun/tests/configure.rs b/tun/tests/configure.rs index 6ef597b..e7e2c6d 100644 --- a/tun/tests/configure.rs +++ b/tun/tests/configure.rs @@ -5,7 +5,9 @@ use tun::TunInterface; #[test] #[throws] -fn test_create() { TunInterface::new()?; } +fn test_create() { + TunInterface::new()?; +} #[test] #[throws] From 30cd00fc2b0136057a9dd744f171b68216390ac3 Mon Sep 17 00:00:00 2001 From: Conrad Kramer Date: Sat, 9 Dec 2023 18:37:13 -0800 Subject: [PATCH 09/10] Conrad's changes --- .../PacketTunnelProvider.swift | 76 +++++++++---------- burrow/src/apple.rs | 4 +- burrow/src/daemon/command.rs | 5 +- burrow/src/daemon/instance.rs | 21 +++-- burrow/src/daemon/mod.rs | 19 +++-- burrow/src/daemon/net/apple.rs | 7 +- burrow/src/daemon/net/unix.rs | 20 ++--- burrow/src/daemon/response.rs | 1 - burrow/src/ensureroot.rs | 40 ---------- burrow/src/lib.rs | 19 ++--- burrow/src/main.rs | 14 ++-- burrow/src/wireguard/iface.rs | 34 +++------ burrow/src/wireguard/noise/handshake.rs | 40 +++++----- burrow/src/wireguard/noise/mod.rs | 22 +++--- burrow/src/wireguard/noise/rate_limiter.rs | 19 +++-- burrow/src/wireguard/noise/session.rs | 14 ++-- burrow/src/wireguard/noise/timers.rs | 10 +-- burrow/src/wireguard/pcb.rs | 32 +++----- burrow/src/wireguard/peer.rs | 2 - tun/src/lib.rs | 2 +- tun/src/options.rs | 2 - tun/src/tokio/mod.rs | 4 +- tun/src/unix/apple/mod.rs | 51 +++++++++---- tun/src/unix/mod.rs | 42 +--------- 24 files changed, 206 insertions(+), 294 deletions(-) delete mode 100644 burrow/src/ensureroot.rs diff --git a/Apple/NetworkExtension/PacketTunnelProvider.swift b/Apple/NetworkExtension/PacketTunnelProvider.swift index 8260aa0..e9c48dd 100644 --- a/Apple/NetworkExtension/PacketTunnelProvider.swift +++ b/Apple/NetworkExtension/PacketTunnelProvider.swift @@ -6,7 +6,7 @@ class PacketTunnelProvider: NEPacketTunnelProvider { let logger = Logger(subsystem: "com.hackclub.burrow", category: "frontend") var client: BurrowIpc? var osInitialized = false - override func startTunnel(options: [String: NSObject]?, completionHandler: @escaping (Error?) -> Void) { + override func startTunnel(options: [String : NSObject]? = nil) async throws { logger.log("Starting tunnel") if !osInitialized { libburrow.initialize_oslog() @@ -15,38 +15,35 @@ class PacketTunnelProvider: NEPacketTunnelProvider { libburrow.start_srv() client = BurrowIpc(logger: logger) logger.info("Started server") - Task { - do { - let command = BurrowSingleCommand(id: 0, command: "ServerConfig") - guard let data = try await client?.request(command, type: Response>.self) - else { - throw BurrowError.cantParseResult - } - let encoded = try JSONEncoder().encode(data.result) - self.logger.log("Received final data: \(String(decoding: encoded, as: UTF8.self))") - guard let serverconfig = data.result.Ok else { - throw BurrowError.resultIsError - } - guard let tunNs = self.generateTunSettings(from: serverconfig) else { - throw BurrowError.addrDoesntExist - } - try await self.setTunnelNetworkSettings(tunNs) - self.logger.info("Set remote tunnel address to \(tunNs.tunnelRemoteAddress)") - -// let tunFd = self.packetFlow.value(forKeyPath: "socket.fileDescriptor") as! Int; -// self.logger.info("Found File Descriptor: \(tunFd)") - let start_command = start_req_fd(id: 1, fd: 0) - guard let data = try await client?.request(start_command, type: Response>.self) - else { - throw BurrowError.cantParseResult - } - let encoded_startres = try JSONEncoder().encode(data.result) - self.logger.log("Received start server response: \(String(decoding: encoded_startres, as: UTF8.self))") - completionHandler(nil) - } catch { - self.logger.error("An error occurred: \(error)") - completionHandler(error) + do { + let command = BurrowSingleCommand(id: 0, command: "ServerConfig") + guard let data = try await client?.request(command, type: Response>.self) + else { + throw BurrowError.cantParseResult } + let encoded = try JSONEncoder().encode(data.result) + self.logger.log("Received final data: \(String(decoding: encoded, as: UTF8.self))") + guard let serverconfig = data.result.Ok else { + throw BurrowError.resultIsError + } + guard let tunNs = self.generateTunSettings(from: serverconfig) else { + throw BurrowError.addrDoesntExist + } + try await self.setTunnelNetworkSettings(tunNs) + self.logger.info("Set remote tunnel address to \(tunNs.tunnelRemoteAddress)") + + // let tunFd = self.packetFlow.value(forKeyPath: "socket.fileDescriptor") as! Int; + // self.logger.info("Found File Descriptor: \(tunFd)") + let start_command = start_req_fd(id: 1, fd: 0) + guard let data = try await client?.request(start_command, type: Response>.self) + else { + throw BurrowError.cantParseResult + } + let encoded_startres = try JSONEncoder().encode(data.result) + self.logger.log("Received start server response: \(String(decoding: encoded_startres, as: UTF8.self))") + } catch { + self.logger.error("An error occurred: \(error)") + throw error } } private func generateTunSettings(from: ServerConfigData) -> NETunnelNetworkSettings? { @@ -60,17 +57,16 @@ class PacketTunnelProvider: NEPacketTunnelProvider { logger.log("Initialized ipv4 settings: \(nst.ipv4Settings)") return nst } - override func stopTunnel(with reason: NEProviderStopReason, completionHandler: @escaping () -> Void) { - completionHandler() + override func stopTunnel(with reason: NEProviderStopReason) async { + } - override func handleAppMessage(_ messageData: Data, completionHandler: ((Data?) -> Void)?) { - if let handler = completionHandler { - handler(messageData) - } + override func handleAppMessage(_ messageData: Data) async -> Data? { + messageData } - override func sleep(completionHandler: @escaping () -> Void) { - completionHandler() + override func sleep() async { + } override func wake() { + } } diff --git a/burrow/src/apple.rs b/burrow/src/apple.rs index dd50fc2..571b413 100644 --- a/burrow/src/apple.rs +++ b/burrow/src/apple.rs @@ -1,8 +1,6 @@ -use tracing::instrument::WithSubscriber; -use tracing::{debug, Subscriber}; +use tracing::debug; use tracing_oslog::OsLogger; use tracing_subscriber::layer::SubscriberExt; -use tracing_subscriber::FmtSubscriber; pub use crate::daemon::start_srv; diff --git a/burrow/src/daemon/command.rs b/burrow/src/daemon/command.rs index 776e172..53b4108 100644 --- a/burrow/src/daemon/command.rs +++ b/burrow/src/daemon/command.rs @@ -23,10 +23,7 @@ fn test_daemoncommand_serialization() { .unwrap()); insta::assert_snapshot!( serde_json::to_string(&DaemonCommand::Start(DaemonStartOptions { - tun: TunOptions { - seek_utun: true, - ..TunOptions::default() - } + tun: TunOptions { ..TunOptions::default() } })) .unwrap() ); diff --git a/burrow/src/daemon/instance.rs b/burrow/src/daemon/instance.rs index c79da05..98052d2 100644 --- a/burrow/src/daemon/instance.rs +++ b/burrow/src/daemon/instance.rs @@ -1,10 +1,17 @@ -use super::*; -use crate::daemon::response::{DaemonResponseData, ServerConfig, ServerInfo}; -use tokio::task::JoinHandle; -use tracing::field::debug; +use std::sync::Arc; + +use anyhow::Result; +use tokio::{sync::RwLock, task::JoinHandle}; use tracing::{debug, info, warn}; use tun::tokio::TunInterface; -use DaemonResponse; + +use crate::{ + daemon::{ + command::DaemonCommand, + response::{DaemonResponse, DaemonResponseData, ServerConfig, ServerInfo}, + }, + wireguard::Interface, +}; enum RunState { Running(JoinHandle>), @@ -48,7 +55,9 @@ impl DaemonInstance { } RunState::Idle => { debug!("Creating new TunInterface"); - let tun_if = Arc::new(RwLock::new(TunInterface::new(st.tun.open()?)?)); + let retrieved = TunInterface::new(tun::TunInterface::retrieve().unwrap())?; + let tun_if = Arc::new(RwLock::new(retrieved)); + // let tun_if = Arc::new(RwLock::new(TunInterface::new(st.tun.open()?)?)); debug!("TunInterface created: {:?}", tun_if); debug!("Setting tun_interface"); diff --git a/burrow/src/daemon/mod.rs b/burrow/src/daemon/mod.rs index 1aa6ea4..394ebec 100644 --- a/burrow/src/daemon/mod.rs +++ b/burrow/src/daemon/mod.rs @@ -1,26 +1,29 @@ -use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4, ToSocketAddrs}; -use std::sync::Arc; +use std::{ + net::{IpAddr, Ipv4Addr, SocketAddr}, + sync::Arc, +}; mod command; mod instance; mod net; mod response; -use crate::wireguard::{Interface, Peer, PublicKey, StaticSecret}; use anyhow::{Error, Result}; use base64::{engine::general_purpose, Engine as _}; pub use command::{DaemonCommand, DaemonStartOptions}; use fehler::throws; use instance::DaemonInstance; use ip_network::{IpNetwork, Ipv4Network}; -pub use net::DaemonClient; -use tokio::sync::RwLock; - #[cfg(target_vendor = "apple")] pub use net::start_srv; - -use crate::daemon::net::listen; +pub use net::DaemonClient; pub use response::{DaemonResponse, DaemonResponseData, ServerInfo}; +use tokio::sync::RwLock; + +use crate::{ + daemon::net::listen, + wireguard::{Interface, Peer, PublicKey, StaticSecret}, +}; #[throws] fn parse_key(string: &str) -> [u8; 32] { diff --git a/burrow/src/daemon/net/apple.rs b/burrow/src/daemon/net/apple.rs index 1242dfe..b84ec08 100644 --- a/burrow/src/daemon/net/apple.rs +++ b/burrow/src/daemon/net/apple.rs @@ -1,9 +1,10 @@ -use crate::daemon::{daemon_main, DaemonClient}; -use std::future::Future; use std::thread; + use tokio::runtime::Runtime; use tracing::{error, info}; +use crate::daemon::{daemon_main, DaemonClient}; + #[no_mangle] pub extern "C" fn start_srv() { info!("Rust: Starting server"); @@ -20,7 +21,7 @@ pub extern "C" fn start_srv() { loop { match DaemonClient::new().await { Ok(_) => break, - Err(e) => { + Err(_e) => { // error!("Error when connecting to daemon: {}", e) } } diff --git a/burrow/src/daemon/net/unix.rs b/burrow/src/daemon/net/unix.rs index d0e5b26..7ce8992 100644 --- a/burrow/src/daemon/net/unix.rs +++ b/burrow/src/daemon/net/unix.rs @@ -1,25 +1,21 @@ -use super::*; -use anyhow::anyhow; -use log::log; -use std::hash::Hash; -use std::path::PathBuf; use std::{ - ascii, io, + io, os::{ fd::{FromRawFd, RawFd}, unix::net::UnixListener as StdUnixListener, }, - path::Path, + path::{Path, PathBuf}, }; -use tracing::info; -use crate::daemon::{DaemonCommand, DaemonResponse, DaemonResponseData}; -use anyhow::Result; +use anyhow::{anyhow, Result}; use tokio::{ io::{AsyncBufReadExt, AsyncWriteExt, BufReader}, net::{UnixListener, UnixStream}, }; -use tracing::debug; +use tracing::{debug, info}; + +use super::*; +use crate::daemon::{DaemonCommand, DaemonResponse, DaemonResponseData}; #[cfg(not(target_vendor = "apple"))] const UNIX_SOCKET_PATH: &str = "/run/burrow.sock"; @@ -40,7 +36,7 @@ fn fetch_socket_path() -> Option { for path in tries { let path = PathBuf::from(path); if path.exists() { - return Some(path); + return Some(path) } } None diff --git a/burrow/src/daemon/response.rs b/burrow/src/daemon/response.rs index 63d10e8..4bebe14 100644 --- a/burrow/src/daemon/response.rs +++ b/burrow/src/daemon/response.rs @@ -1,4 +1,3 @@ -use anyhow::anyhow; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use tun::TunInterface; diff --git a/burrow/src/ensureroot.rs b/burrow/src/ensureroot.rs deleted file mode 100644 index b7c0757..0000000 --- a/burrow/src/ensureroot.rs +++ /dev/null @@ -1,40 +0,0 @@ -use tracing::instrument; - -// Check capabilities on Linux -#[cfg(target_os = "linux")] -#[instrument] -pub fn ensure_root() { - use caps::{has_cap, CapSet, Capability}; - - let cap_net_admin = Capability::CAP_NET_ADMIN; - if let Ok(has_cap) = has_cap(None, CapSet::Effective, cap_net_admin) { - if !has_cap { - eprintln!( - "This action needs the CAP_NET_ADMIN permission. Did you mean to run it as root?" - ); - std::process::exit(77); - } - } else { - eprintln!("Failed to check capabilities. Please file a bug report!"); - std::process::exit(71); - } -} - -// Check for root user on macOS -#[cfg(target_vendor = "apple")] -#[instrument] -pub fn ensure_root() { - use nix::unistd::Uid; - - let current_uid = Uid::current(); - if !current_uid.is_root() { - eprintln!("This action must be run as root!"); - std::process::exit(77); - } -} - -#[cfg(target_family = "windows")] -#[instrument] -pub fn ensure_root() { - todo!() -} diff --git a/burrow/src/lib.rs b/burrow/src/lib.rs index 07cb2f6..9df60f0 100644 --- a/burrow/src/lib.rs +++ b/burrow/src/lib.rs @@ -1,21 +1,12 @@ -pub mod ensureroot; pub mod wireguard; -use anyhow::Result; - -#[cfg(any(target_os = "linux", target_vendor = "apple"))] -use std::{ - mem, - os::fd::{AsRawFd, FromRawFd}, -}; - -use tun::TunInterface; - -// TODO Separate start and retrieve functions - mod daemon; pub use daemon::{ - DaemonCommand, DaemonResponse, DaemonResponseData, DaemonStartOptions, ServerInfo, + DaemonCommand, + DaemonResponse, + DaemonResponseData, + DaemonStartOptions, + ServerInfo, }; #[cfg(target_vendor = "apple")] diff --git a/burrow/src/main.rs b/burrow/src/main.rs index ff0ed53..5003277 100644 --- a/burrow/src/main.rs +++ b/burrow/src/main.rs @@ -1,24 +1,21 @@ -use std::mem; -#[cfg(any(target_os = "linux", target_vendor = "apple"))] -use std::os::fd::FromRawFd; - use anyhow::{Context, Result}; #[cfg(any(target_os = "linux", target_vendor = "apple"))] use clap::{Args, Parser, Subcommand}; -use tracing::{instrument, Level}; +use tracing::instrument; use tracing_log::LogTracer; use tracing_oslog::OsLogger; use tracing_subscriber::{prelude::*, EnvFilter, FmtSubscriber}; #[cfg(any(target_os = "linux", target_vendor = "apple"))] -use tun::{retrieve, TunInterface}; +use tun::TunInterface; mod daemon; mod wireguard; -use crate::daemon::DaemonResponseData; use daemon::{DaemonClient, DaemonCommand, DaemonStartOptions}; use tun::TunOptions; +use crate::daemon::DaemonResponseData; + #[derive(Parser)] #[command(name = "Burrow")] #[command(author = "Hack Club ")] @@ -88,8 +85,7 @@ async fn try_retrieve() -> Result<()> { } } - burrow::ensureroot::ensure_root(); - let iface2 = retrieve().ok_or(anyhow::anyhow!("No interface found"))?; + let iface2 = TunInterface::retrieve().ok_or(anyhow::anyhow!("No interface found"))?; tracing::info!("{:?}", iface2); Ok(()) } diff --git a/burrow/src/wireguard/iface.rs b/burrow/src/wireguard/iface.rs index 52f719b..9a0c216 100755 --- a/burrow/src/wireguard/iface.rs +++ b/burrow/src/wireguard/iface.rs @@ -1,24 +1,15 @@ -use std::sync::Arc; -use std::time::Duration; -use std::{net::IpAddr, rc::Rc}; +use std::{net::IpAddr, sync::Arc, time::Duration}; use anyhow::Error; use async_trait::async_trait; use fehler::throws; -use futures::future::join_all; -use futures::FutureExt; +use futures::{future::join_all, FutureExt}; use ip_network_table::IpNetworkTable; -use log::log; -use tokio::time::timeout; -use tokio::{ - join, - sync::{Mutex, RwLock}, - task::{self, JoinHandle}, -}; +use tokio::{sync::RwLock, task::JoinHandle, time::timeout}; use tracing::{debug, error}; use tun::tokio::TunInterface; -use super::{noise::Tunnel, pcb, Peer, PeerPcb}; +use super::{noise::Tunnel, Peer, PeerPcb}; #[async_trait] pub trait PacketInterface { @@ -122,12 +113,9 @@ impl Interface { Ok(Ok(len)) => &buf[..len], Ok(Err(e)) => { error!("failed to read from interface: {}", e); - continue; - } - Err(_would_block) => { - debug!("read timed out"); - continue; + continue } + Err(_would_block) => continue, }; debug!("read {} bytes from interface", src.len()); debug!("bytes: {:?}", src); @@ -138,7 +126,7 @@ impl Interface { Some(addr) => addr, None => { tracing::debug!("no destination found"); - continue; + continue } }; @@ -156,7 +144,7 @@ impl Interface { } Err(e) => { log::error!("failed to send packet {}", e); - continue; + continue } }; } @@ -175,20 +163,20 @@ impl Interface { let pcbs = &self.pcbs; for i in 0..pcbs.pcbs.len() { debug!("spawning read task for peer {}", i); - let mut pcb = pcbs.pcbs[i].clone(); + let pcb = pcbs.pcbs[i].clone(); let tun = tun.clone(); let tsk = async move { { let r1 = pcb.write().await.open_if_closed().await; if let Err(e) = r1 { log::error!("failed to open pcb: {}", e); - return; + return } } let r2 = pcb.read().await.run(tun).await; if let Err(e) = r2 { log::error!("failed to run pcb: {}", e); - return; + return } else { tracing::debug!("pcb ran successfully"); } diff --git a/burrow/src/wireguard/noise/handshake.rs b/burrow/src/wireguard/noise/handshake.rs index 3f8c91b..2ec0c6a 100755 --- a/burrow/src/wireguard/noise/handshake.rs +++ b/burrow/src/wireguard/noise/handshake.rs @@ -9,14 +9,20 @@ use std::{ use aead::{Aead, Payload}; use blake2::{ digest::{FixedOutput, KeyInit}, - Blake2s256, Blake2sMac, Digest, + Blake2s256, + Blake2sMac, + Digest, }; use chacha20poly1305::XChaCha20Poly1305; use rand_core::OsRng; use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305}; use super::{ - errors::WireGuardError, session::Session, x25519, HandshakeInit, HandshakeResponse, + errors::WireGuardError, + session::Session, + x25519, + HandshakeInit, + HandshakeResponse, PacketCookieReply, }; @@ -130,10 +136,6 @@ fn aead_chacha20_open( ) -> Result<(), WireGuardError> { let mut nonce: [u8; 12] = [0; 12]; nonce[4..].copy_from_slice(&counter.to_le_bytes()); - - tracing::debug!("TAG A"); - tracing::debug!("{:?};{:?};{:?};{:?};{}", key, data, aad, nonce, counter); - aead_chacha20_open_inner(buffer, key, nonce, data, aad) .map_err(|_| WireGuardError::InvalidAeadTag)?; Ok(()) @@ -207,7 +209,7 @@ impl Tai64N { /// Parse a timestamp from a 12 byte u8 slice fn parse(buf: &[u8; 12]) -> Result { if buf.len() < 12 { - return Err(WireGuardError::InvalidTai64nTimestamp); + return Err(WireGuardError::InvalidTai64nTimestamp) } let (sec_bytes, nano_bytes) = buf.split_at(std::mem::size_of::()); @@ -554,22 +556,19 @@ impl Handshake { let timestamp = Tai64N::parse(×tamp)?; if !timestamp.after(&self.last_handshake_timestamp) { // Possibly a replay - return Err(WireGuardError::WrongTai64nTimestamp); + return Err(WireGuardError::WrongTai64nTimestamp) } self.last_handshake_timestamp = timestamp; // initiator.hash = HASH(initiator.hash || msg.encrypted_timestamp) hash = b2s_hash(&hash, packet.encrypted_timestamp); - self.previous = std::mem::replace( - &mut self.state, - HandshakeState::InitReceived { - chaining_key, - hash, - peer_ephemeral_public, - peer_index, - }, - ); + self.previous = std::mem::replace(&mut self.state, HandshakeState::InitReceived { + chaining_key, + hash, + peer_ephemeral_public, + peer_index, + }); self.format_handshake_response(dst) } @@ -670,7 +669,7 @@ impl Handshake { let local_index = self.cookies.index; if packet.receiver_idx != local_index { - return Err(WireGuardError::WrongIndex); + return Err(WireGuardError::WrongIndex) } // msg.encrypted_cookie = XAEAD(HASH(LABEL_COOKIE || responder.static_public), // msg.nonce, cookie, last_received_msg.mac1) @@ -680,7 +679,6 @@ impl Handshake { aad: &mac1[0..16], msg: packet.encrypted_cookie, }; - tracing::debug!("TAG B"); let plaintext = XChaCha20Poly1305::new_from_slice(&key) .unwrap() .decrypt(packet.nonce.into(), payload) @@ -727,7 +725,7 @@ impl Handshake { dst: &'a mut [u8], ) -> Result<&'a mut [u8], WireGuardError> { if dst.len() < super::HANDSHAKE_INIT_SZ { - return Err(WireGuardError::DestinationBufferTooSmall); + return Err(WireGuardError::DestinationBufferTooSmall) } let (message_type, rest) = dst.split_at_mut(4); @@ -810,7 +808,7 @@ impl Handshake { dst: &'a mut [u8], ) -> Result<(&'a mut [u8], Session), WireGuardError> { if dst.len() < super::HANDSHAKE_RESP_SZ { - return Err(WireGuardError::DestinationBufferTooSmall); + return Err(WireGuardError::DestinationBufferTooSmall) } let state = std::mem::replace(&mut self.state, HandshakeState::None); diff --git a/burrow/src/wireguard/noise/mod.rs b/burrow/src/wireguard/noise/mod.rs index 824d7c1..3a60c22 100755 --- a/burrow/src/wireguard/noise/mod.rs +++ b/burrow/src/wireguard/noise/mod.rs @@ -45,7 +45,11 @@ const N_SESSIONS: usize = 8; pub mod x25519 { pub use x25519_dalek::{ - EphemeralSecret, PublicKey, ReusableSecret, SharedSecret, StaticSecret, + EphemeralSecret, + PublicKey, + ReusableSecret, + SharedSecret, + StaticSecret, }; } @@ -137,7 +141,7 @@ impl Tunnel { #[inline(always)] pub fn parse_incoming_packet(src: &[u8]) -> Result { if src.len() < 4 { - return Err(WireGuardError::InvalidPacket); + return Err(WireGuardError::InvalidPacket) } // Checks the type, as well as the reserved zero fields @@ -179,7 +183,7 @@ impl Tunnel { pub fn dst_address(packet: &[u8]) -> Option { if packet.is_empty() { - return None; + return None } match packet[0] >> 4 { @@ -274,7 +278,7 @@ impl Tunnel { self.timer_tick(TimerName::TimeLastDataPacketSent); } self.tx_bytes += src.len(); - return TunnResult::WriteToNetwork(packet); + return TunnResult::WriteToNetwork(packet) } // If there is no session, queue the packet for future retry @@ -298,7 +302,7 @@ impl Tunnel { ) -> TunnResult<'a> { if datagram.is_empty() { // Indicates a repeated call - return self.send_queued_packet(dst); + return self.send_queued_packet(dst) } let mut cookie = [0u8; COOKIE_REPLY_SZ]; @@ -309,7 +313,7 @@ impl Tunnel { Ok(packet) => packet, Err(TunnResult::WriteToNetwork(cookie)) => { dst[..cookie.len()].copy_from_slice(cookie); - return TunnResult::WriteToNetwork(&mut dst[..cookie.len()]); + return TunnResult::WriteToNetwork(&mut dst[..cookie.len()]) } Err(TunnResult::Err(e)) => return TunnResult::Err(e), _ => unreachable!(), @@ -409,7 +413,7 @@ impl Tunnel { let cur_idx = self.current; if cur_idx == new_idx { // There is nothing to do, already using this session, this is the common case - return; + return } if self.sessions[cur_idx % N_SESSIONS].is_none() || self.timers.session_timers[new_idx % N_SESSIONS] @@ -455,7 +459,7 @@ impl Tunnel { force_resend: bool, ) -> TunnResult<'a> { if self.handshake.is_in_progress() && !force_resend { - return TunnResult::Done; + return TunnResult::Done } if self.handshake.is_expired() { @@ -514,7 +518,7 @@ impl Tunnel { }; if computed_len > packet.len() { - return TunnResult::Err(WireGuardError::InvalidPacket); + return TunnResult::Err(WireGuardError::InvalidPacket) } self.timer_tick(TimerName::TimeLastDataPacketReceived); diff --git a/burrow/src/wireguard/noise/rate_limiter.rs b/burrow/src/wireguard/noise/rate_limiter.rs index 02887ee..ff19efd 100755 --- a/burrow/src/wireguard/noise/rate_limiter.rs +++ b/burrow/src/wireguard/noise/rate_limiter.rs @@ -6,16 +6,25 @@ use std::{ use aead::{generic_array::GenericArray, AeadInPlace, KeyInit}; use chacha20poly1305::{Key, XChaCha20Poly1305}; -use log::log; use parking_lot::Mutex; use rand_core::{OsRng, RngCore}; use ring::constant_time::verify_slices_are_equal; use super::{ handshake::{ - b2s_hash, b2s_keyed_mac_16, b2s_keyed_mac_16_2, b2s_mac_24, LABEL_COOKIE, LABEL_MAC1, + b2s_hash, + b2s_keyed_mac_16, + b2s_keyed_mac_16_2, + b2s_mac_24, + LABEL_COOKIE, + LABEL_MAC1, }, - HandshakeInit, HandshakeResponse, Packet, TunnResult, Tunnel, WireGuardError, + HandshakeInit, + HandshakeResponse, + Packet, + TunnResult, + Tunnel, + WireGuardError, }; const COOKIE_REFRESH: u64 = 128; // Use 128 and not 120 so the compiler can optimize out the division @@ -127,7 +136,7 @@ impl RateLimiter { dst: &'a mut [u8], ) -> Result<&'a mut [u8], WireGuardError> { if dst.len() < super::COOKIE_REPLY_SZ { - return Err(WireGuardError::DestinationBufferTooSmall); + return Err(WireGuardError::DestinationBufferTooSmall) } let (message_type, rest) = dst.split_at_mut(4); @@ -193,7 +202,7 @@ impl RateLimiter { let cookie_packet = self .format_cookie_reply(sender_idx, cookie, mac1, dst) .map_err(TunnResult::Err)?; - return Err(TunnResult::WriteToNetwork(cookie_packet)); + return Err(TunnResult::WriteToNetwork(cookie_packet)) } } } diff --git a/burrow/src/wireguard/noise/session.rs b/burrow/src/wireguard/noise/session.rs index 14c191b..8988728 100755 --- a/burrow/src/wireguard/noise/session.rs +++ b/burrow/src/wireguard/noise/session.rs @@ -88,11 +88,11 @@ impl ReceivingKeyCounterValidator { fn will_accept(&self, counter: u64) -> Result<(), WireGuardError> { if counter >= self.next { // As long as the counter is growing no replay took place for sure - return Ok(()); + return Ok(()) } if counter + N_BITS < self.next { // Drop if too far back - return Err(WireGuardError::InvalidCounter); + return Err(WireGuardError::InvalidCounter) } if !self.check_bit(counter) { Ok(()) @@ -107,22 +107,22 @@ impl ReceivingKeyCounterValidator { fn mark_did_receive(&mut self, counter: u64) -> Result<(), WireGuardError> { if counter + N_BITS < self.next { // Drop if too far back - return Err(WireGuardError::InvalidCounter); + return Err(WireGuardError::InvalidCounter) } if counter == self.next { // Usually the packets arrive in order, in that case we simply mark the bit and // increment the counter self.set_bit(counter); self.next += 1; - return Ok(()); + return Ok(()) } if counter < self.next { // A packet arrived out of order, check if it is valid, and mark if self.check_bit(counter) { - return Err(WireGuardError::InvalidCounter); + return Err(WireGuardError::InvalidCounter) } self.set_bit(counter); - return Ok(()); + return Ok(()) } // Packets where dropped, or maybe reordered, skip them and mark unused if counter - self.next >= N_BITS { @@ -247,7 +247,7 @@ impl Session { panic!("The destination buffer is too small"); } if packet.receiver_idx != self.receiving_index { - return Err(WireGuardError::WrongIndex); + return Err(WireGuardError::WrongIndex) } // Don't reuse counters, in case this is a replay attack we want to quickly // check the counter without running expensive decryption diff --git a/burrow/src/wireguard/noise/timers.rs b/burrow/src/wireguard/noise/timers.rs index f713e6f..1d0cf1f 100755 --- a/burrow/src/wireguard/noise/timers.rs +++ b/burrow/src/wireguard/noise/timers.rs @@ -190,7 +190,7 @@ impl Tunnel { { if self.handshake.is_expired() { - return TunnResult::Err(WireGuardError::ConnectionExpired); + return TunnResult::Err(WireGuardError::ConnectionExpired) } // Clear cookie after COOKIE_EXPIRATION_TIME @@ -206,7 +206,7 @@ impl Tunnel { tracing::error!("CONNECTION_EXPIRED(REJECT_AFTER_TIME * 3)"); self.handshake.set_expired(); self.clear_all(); - return TunnResult::Err(WireGuardError::ConnectionExpired); + return TunnResult::Err(WireGuardError::ConnectionExpired) } if let Some(time_init_sent) = self.handshake.timer() { @@ -219,7 +219,7 @@ impl Tunnel { tracing::error!("CONNECTION_EXPIRED(REKEY_ATTEMPT_TIME)"); self.handshake.set_expired(); self.clear_all(); - return TunnResult::Err(WireGuardError::ConnectionExpired); + return TunnResult::Err(WireGuardError::ConnectionExpired) } if time_init_sent.elapsed() >= REKEY_TIMEOUT { @@ -299,11 +299,11 @@ impl Tunnel { } if handshake_initiation_required { - return self.format_handshake_initiation(dst, true); + return self.format_handshake_initiation(dst, true) } if keepalive_required { - return self.encapsulate(&[], dst); + return self.encapsulate(&[], dst) } TunnResult::Done diff --git a/burrow/src/wireguard/pcb.rs b/burrow/src/wireguard/pcb.rs index 6acc8d8..d11e736 100755 --- a/burrow/src/wireguard/pcb.rs +++ b/burrow/src/wireguard/pcb.rs @@ -1,19 +1,11 @@ -use std::io; -use std::net::SocketAddr; -use std::rc::Rc; -use std::sync::Arc; -use std::time::Duration; +use std::{net::SocketAddr, sync::Arc}; use anyhow::{anyhow, Error}; use fehler::throws; use ip_network::IpNetwork; -use log::log; use rand::random; -use tokio::sync::{Mutex, RwLock}; -use tokio::time::timeout; -use tokio::{net::UdpSocket, task::JoinHandle}; +use tokio::{net::UdpSocket, sync::RwLock, task::JoinHandle}; use tun::tokio::TunInterface; -use uuid::uuid; use super::{ iface::PacketInterface, @@ -83,16 +75,14 @@ impl PeerPcb { tracing::debug!("start read loop {}", rid); loop { tracing::debug!("{}: waiting for packet", rid); - let Some(socket) = &self.socket else { - continue - }; + let Some(socket) = &self.socket else { continue }; let mut res_buf = [0; 1500]; // tracing::debug!("{} : waiting for readability on {:?}", rid, socket); let len = match socket.recv(&mut res_buf).await { Ok(l) => l, Err(e) => { log::error!("{}: error reading from socket: {:?}", rid, e); - continue; + continue } }; let mut res_dat = &res_buf[..len]; @@ -105,33 +95,31 @@ impl PeerPcb { .await .decapsulate(None, res_dat, &mut buf[..]) { - TunnResult::Done => { - break; - } + TunnResult::Done => break, TunnResult::Err(e) => { tracing::error!(message = "Decapsulate error", error = ?e); - break; + break } TunnResult::WriteToNetwork(packet) => { tracing::debug!("WriteToNetwork: {:?}", packet); socket.send(packet).await?; tracing::debug!("WriteToNetwork done"); res_dat = &[]; - continue; + continue } TunnResult::WriteToTunnelV4(packet, addr) => { tracing::debug!("WriteToTunnelV4: {:?}, {:?}", packet, addr); tun_interface.read().await.send(packet).await?; - break; + break } TunnResult::WriteToTunnelV6(packet, addr) => { tracing::debug!("WriteToTunnelV6: {:?}, {:?}", packet, addr); tun_interface.read().await.send(packet).await?; - break; + break } } } - return Ok(len); + return Ok(len) } } diff --git a/burrow/src/wireguard/peer.rs b/burrow/src/wireguard/peer.rs index 27c5399..131b0d4 100755 --- a/burrow/src/wireguard/peer.rs +++ b/burrow/src/wireguard/peer.rs @@ -1,7 +1,5 @@ use std::{fmt, net::SocketAddr}; -use anyhow::Error; -use fehler::throws; use ip_network::IpNetwork; use x25519_dalek::{PublicKey, StaticSecret}; diff --git a/tun/src/lib.rs b/tun/src/lib.rs index 64e17df..a1ca636 100644 --- a/tun/src/lib.rs +++ b/tun/src/lib.rs @@ -15,4 +15,4 @@ mod options; pub mod tokio; pub use options::TunOptions; -pub use os_imp::{retrieve, TunInterface, TunQueue}; +pub use os_imp::{TunInterface, TunQueue}; diff --git a/tun/src/options.rs b/tun/src/options.rs index aafdad2..7c414dc 100644 --- a/tun/src/options.rs +++ b/tun/src/options.rs @@ -16,8 +16,6 @@ pub struct TunOptions { pub no_pi: bool, /// (Linux) Avoid opening an existing persistant device. pub tun_excl: bool, - /// (MacOS) Whether to seek the first available utun device. - pub seek_utun: Option, /// (Linux) The IP address of the tun interface. pub address: Option, } diff --git a/tun/src/tokio/mod.rs b/tun/src/tokio/mod.rs index c901cba..525e4d7 100644 --- a/tun/src/tokio/mod.rs +++ b/tun/src/tokio/mod.rs @@ -39,11 +39,11 @@ impl TunInterface { }) { Ok(result) => { tracing::debug!("HORRAY"); - return result; + return result } Err(_would_block) => { tracing::debug!("WouldBlock"); - continue; + continue } } } diff --git a/tun/src/unix/apple/mod.rs b/tun/src/unix/apple/mod.rs index ca3ddc7..ab08505 100644 --- a/tun/src/unix/apple/mod.rs +++ b/tun/src/unix/apple/mod.rs @@ -1,24 +1,24 @@ use std::{ io::{Error, IoSlice}, - mem, + mem::{self, ManuallyDrop}, net::{Ipv4Addr, SocketAddrV4}, - os::fd::{AsRawFd, RawFd}, + os::fd::{AsRawFd, FromRawFd, RawFd}, }; use byteorder::{ByteOrder, NetworkEndian}; use fehler::throws; use libc::{c_char, iovec, writev, AF_INET, AF_INET6}; use socket2::{Domain, SockAddr, Socket, Type}; -use tracing::{self, debug, instrument}; +use tracing::{self, instrument}; -mod kern_control; +pub mod kern_control; pub mod sys; -use crate::retrieve; use kern_control::SysControlSocket; pub use super::queue::TunQueue; -use super::{ifname_to_string, string_to_ifname, TunOptions}; +use super::{ifname_to_string, string_to_ifname}; +use crate::TunOptions; #[derive(Debug)] pub struct TunInterface { @@ -35,18 +35,41 @@ impl TunInterface { #[throws] #[instrument] pub fn new_with_options(options: TunOptions) -> TunInterface { - debug!("Opening tun interface with options: {:?}", &options); - let ti = if let Some(n) = options.seek_utun { - retrieve().ok_or(Error::new(std::io::ErrorKind::NotFound, "No utun found"))? - } else { - TunInterface::connect(0)? - }; + let ti = TunInterface::connect(0)?; + ti.configure(options)?; + ti + } + + pub fn retrieve() -> Option { + (3..100) + .filter_map(|fd| unsafe { + let peer_addr = socket2::SockAddr::init(|storage, len| { + *len = mem::size_of::() as u32; + libc::getpeername(fd, storage as *mut _, len); + Ok(()) + }) + .map(|(_, addr)| (fd, addr)); + peer_addr.ok() + }) + .filter(|(_fd, addr)| { + let ctl_addr = unsafe { &*(addr.as_ptr() as *const libc::sockaddr_ctl) }; + addr.family() == libc::AF_SYSTEM as u8 + && ctl_addr.ss_sysaddr == libc::AF_SYS_CONTROL as u16 + }) + .map(|(fd, _)| { + let socket = unsafe { socket2::Socket::from_raw_fd(fd) }; + TunInterface { socket } + }) + .next() + } + + #[throws] + fn configure(&self, options: TunOptions) { if let Some(addr) = options.address { if let Ok(addr) = addr.parse() { - ti.set_ipv4_addr(addr)?; + self.set_ipv4_addr(addr)?; } } - ti } #[throws] diff --git a/tun/src/unix/mod.rs b/tun/src/unix/mod.rs index 72a8795..775ba1d 100644 --- a/tun/src/unix/mod.rs +++ b/tun/src/unix/mod.rs @@ -1,13 +1,9 @@ -use std::mem::size_of; use std::{ io::{Error, Read}, - mem, os::fd::{AsRawFd, FromRawFd, IntoRawFd, RawFd}, }; -use tracing::{debug, error, instrument}; - -use super::{syscall, TunOptions}; +use tracing::instrument; mod queue; @@ -19,13 +15,9 @@ mod imp; #[path = "linux/mod.rs"] mod imp; -use crate::os_imp::imp::sys; -use crate::os_imp::imp::sys::resolve_ctl_info; use fehler::throws; pub use imp::TunInterface; -use libc::{getpeername, sockaddr_ctl, sockaddr_storage, socklen_t, AF_SYSTEM, AF_SYS_CONTROL}; pub use queue::TunQueue; -use socket2::SockAddr; impl AsRawFd for TunInterface { fn as_raw_fd(&self) -> RawFd { @@ -82,35 +74,3 @@ pub fn string_to_ifname(name: &str) -> [libc::c_char; libc::IFNAMSIZ] { buf[..len].copy_from_slice(unsafe { &*(name.as_bytes() as *const _ as *const [libc::c_char]) }); buf } - -#[cfg(any(target_os = "linux", target_vendor = "apple"))] -pub fn retrieve() -> Option { - (3..100) - .filter_map(|i| { - let result = unsafe { - let mut addr = sockaddr_ctl { - sc_len: size_of::() as u8, - sc_family: 0, - ss_sysaddr: 0, - sc_id: 0, - sc_unit: 0, - sc_reserved: Default::default(), - }; - let mut len = mem::size_of::() as libc::socklen_t; - let res = syscall!(getpeername(i, &mut addr as *mut _ as *mut _, len as *mut _)); - tracing::debug!("getpeername{}: {:?}", i, res); - if res.is_err() { - return None; - } - if addr.sc_family == sys::AF_SYSTEM as u8 - && addr.ss_sysaddr == sys::AF_SYS_CONTROL as u16 - { - Some(TunInterface::from_raw_fd(i)) - } else { - None - } - }; - result - }) - .next() -} From 4408e9aca8bc05cf508d5374e00020348acb326b Mon Sep 17 00:00:00 2001 From: Conrad Kramer Date: Sat, 9 Dec 2023 19:47:41 -0800 Subject: [PATCH 10/10] Update locking to be interior to PeerPcb --- burrow/src/daemon/instance.rs | 7 +++-- burrow/src/daemon/mod.rs | 3 +- burrow/src/wireguard/iface.rs | 24 +++++++-------- burrow/src/wireguard/pcb.rs | 56 +++++++++++++++-------------------- tun/src/tokio/mod.rs | 28 ++++-------------- tun/src/unix/mod.rs | 16 ++++++---- 6 files changed, 57 insertions(+), 77 deletions(-) diff --git a/burrow/src/daemon/instance.rs b/burrow/src/daemon/instance.rs index 98052d2..6a430c5 100644 --- a/burrow/src/daemon/instance.rs +++ b/burrow/src/daemon/instance.rs @@ -54,11 +54,12 @@ impl DaemonInstance { warn!("Got start, but tun interface already up."); } RunState::Idle => { - debug!("Creating new TunInterface"); - let retrieved = TunInterface::new(tun::TunInterface::retrieve().unwrap())?; + let raw = tun::TunInterface::retrieve().unwrap(); + debug!("TunInterface retrieved: {:?}", raw.name()?); + + let retrieved = TunInterface::new(raw)?; let tun_if = Arc::new(RwLock::new(retrieved)); // let tun_if = Arc::new(RwLock::new(TunInterface::new(st.tun.open()?)?)); - debug!("TunInterface created: {:?}", tun_if); debug!("Setting tun_interface"); self.tun_interface = Some(tun_if.clone()); diff --git a/burrow/src/daemon/mod.rs b/burrow/src/daemon/mod.rs index 394ebec..1020cf7 100644 --- a/burrow/src/daemon/mod.rs +++ b/burrow/src/daemon/mod.rs @@ -62,7 +62,8 @@ pub async fn daemon_main() -> Result<()> { allowed_ips: vec![IpNetwork::V4(Ipv4Network::DEFAULT_ROUTE)], }])?; - let mut inst = DaemonInstance::new(commands_rx, response_tx, Arc::new(RwLock::new(iface))); + let mut inst: DaemonInstance = + DaemonInstance::new(commands_rx, response_tx, Arc::new(RwLock::new(iface))); tracing::info!("Starting daemon jobs..."); diff --git a/burrow/src/wireguard/iface.rs b/burrow/src/wireguard/iface.rs index 9a0c216..3d1823b 100755 --- a/burrow/src/wireguard/iface.rs +++ b/burrow/src/wireguard/iface.rs @@ -29,7 +29,7 @@ impl PacketInterface for tun::tokio::TunInterface { } struct IndexedPcbs { - pcbs: Vec>>, + pcbs: Vec>, allowed_ips: IpNetworkTable, } @@ -46,7 +46,7 @@ impl IndexedPcbs { for allowed_ip in pcb.allowed_ips.iter() { self.allowed_ips.insert(allowed_ip.clone(), idx); } - self.pcbs.insert(idx, Arc::new(RwLock::new(pcb))); + self.pcbs.insert(idx, Arc::new(pcb)); } pub fn find(&self, addr: IpAddr) -> Option { @@ -55,7 +55,7 @@ impl IndexedPcbs { } pub async fn connect(&self, idx: usize, handle: JoinHandle<()>) { - self.pcbs[idx].write().await.handle = Some(handle); + self.pcbs[idx].handle.write().await.replace(handle); } } @@ -106,7 +106,7 @@ impl Interface { let src = { let src = match timeout( Duration::from_millis(10), - tun.write().await.recv(&mut buf[..]), + tun.read().await.recv(&mut buf[..]), ) .await { @@ -138,9 +138,10 @@ impl Interface { tracing::debug!("found peer:{}", idx); - match pcbs.pcbs[idx].read().await.send(src).await { + match pcbs.pcbs[idx].send(src).await { Ok(..) => { - tracing::debug!("sent packet to peer {}", dst_addr); + let addr = pcbs.pcbs[idx].endpoint; + tracing::debug!("sent packet to peer {}", addr); } Err(e) => { log::error!("failed to send packet {}", e); @@ -166,14 +167,11 @@ impl Interface { let pcb = pcbs.pcbs[i].clone(); let tun = tun.clone(); let tsk = async move { - { - let r1 = pcb.write().await.open_if_closed().await; - if let Err(e) = r1 { - log::error!("failed to open pcb: {}", e); - return - } + if let Err(e) = pcb.open_if_closed().await { + log::error!("failed to open pcb: {}", e); + return } - let r2 = pcb.read().await.run(tun).await; + let r2 = pcb.run(tun).await; if let Err(e) = r2 { log::error!("failed to run pcb: {}", e); return diff --git a/burrow/src/wireguard/pcb.rs b/burrow/src/wireguard/pcb.rs index d11e736..21b1d6e 100755 --- a/burrow/src/wireguard/pcb.rs +++ b/burrow/src/wireguard/pcb.rs @@ -1,4 +1,8 @@ -use std::{net::SocketAddr, sync::Arc}; +use std::{ + cell::{Cell, RefCell}, + net::SocketAddr, + sync::Arc, +}; use anyhow::{anyhow, Error}; use fehler::throws; @@ -8,7 +12,6 @@ use tokio::{net::UdpSocket, sync::RwLock, task::JoinHandle}; use tun::tokio::TunInterface; use super::{ - iface::PacketInterface, noise::{TunnResult, Tunnel}, Peer, }; @@ -17,8 +20,8 @@ use super::{ pub struct PeerPcb { pub endpoint: SocketAddr, pub allowed_ips: Vec, - pub handle: Option>, - socket: Option, + pub handle: RwLock>>, + socket: RwLock>, tunnel: RwLock, } @@ -36,46 +39,35 @@ impl PeerPcb { ) .map_err(|s| anyhow::anyhow!("{}", s))?, ); - Self { endpoint: peer.endpoint, allowed_ips: peer.allowed_ips, - handle: None, - socket: None, + handle: RwLock::new(None), + socket: RwLock::new(None), tunnel, } } - pub async fn open_if_closed(&mut self) -> Result<(), Error> { - if self.socket.is_none() { + pub async fn open_if_closed(&self) -> Result<(), Error> { + if self.socket.read().await.is_none() { let socket = UdpSocket::bind("0.0.0.0:0").await?; socket.connect(self.endpoint).await?; - self.socket = Some(socket); + self.socket.write().await.replace(socket); } Ok(()) } pub async fn run(&self, tun_interface: Arc>) -> Result<(), Error> { - let mut buf = [0u8; 3000]; - tracing::debug!("starting read loop for pcb..."); - loop { - tracing::debug!("waiting for packet"); - let len = self.recv(&mut buf, tun_interface.clone()).await?; - tracing::debug!("received {} bytes", len); - } - } - - pub async fn recv( - &self, - buf: &mut [u8], - tun_interface: Arc>, - ) -> Result { tracing::debug!("starting read loop for pcb... for {:?}", &self); let rid: i32 = random(); + let mut buf: [u8; 3000] = [0u8; 3000]; tracing::debug!("start read loop {}", rid); loop { tracing::debug!("{}: waiting for packet", rid); - let Some(socket) = &self.socket else { continue }; + let guard = self.socket.read().await; + let Some(socket) = guard.as_ref() else { + continue + }; let mut res_buf = [0; 1500]; // tracing::debug!("{} : waiting for readability on {:?}", rid, socket); let len = match socket.recv(&mut res_buf).await { @@ -102,6 +94,7 @@ impl PeerPcb { } TunnResult::WriteToNetwork(packet) => { tracing::debug!("WriteToNetwork: {:?}", packet); + self.open_if_closed().await?; socket.send(packet).await?; tracing::debug!("WriteToNetwork done"); res_dat = &[]; @@ -119,15 +112,9 @@ impl PeerPcb { } } } - return Ok(len) } } - pub async fn socket(&mut self) -> Result<&UdpSocket, Error> { - self.open_if_closed().await?; - Ok(self.socket.as_ref().expect("socket was just opened")) - } - pub async fn send(&self, src: &[u8]) -> Result<(), Error> { let mut dst_buf = [0u8; 3000]; match self.tunnel.write().await.encapsulate(src, &mut dst_buf[..]) { @@ -136,7 +123,12 @@ impl PeerPcb { tracing::error!(message = "Encapsulate error", error = ?e) } TunnResult::WriteToNetwork(packet) => { - let socket = &self.socket.as_ref().ok_or(anyhow!("socket not open"))?; + self.open_if_closed().await?; + let handle = self.socket.read().await; + let Some(socket) = handle.as_ref() else { + tracing::error!("No socket for peer"); + return Ok(()) + }; tracing::debug!("Our Encapsulated packet: {:?}", packet); socket.send(packet).await?; } diff --git a/tun/src/tokio/mod.rs b/tun/src/tokio/mod.rs index 525e4d7..947fb74 100644 --- a/tun/src/tokio/mod.rs +++ b/tun/src/tokio/mod.rs @@ -26,21 +26,12 @@ impl TunInterface { } } - // #[instrument] - pub async fn recv(&mut self, buf: &mut [u8]) -> io::Result { + #[instrument] + pub async fn recv(&self, buf: &mut [u8]) -> io::Result { loop { - // tracing::debug!("TunInterface receiving..."); - let mut guard = self.inner.readable_mut().await?; - // tracing::debug!("Got! readable_mut"); - match guard.try_io(|inner| { - let raw_ref = (*inner).get_mut(); - let recved = raw_ref.recv(buf); - recved - }) { - Ok(result) => { - tracing::debug!("HORRAY"); - return result - } + let mut guard = self.inner.readable().await?; + match guard.try_io(|inner| inner.get_ref().recv(buf)) { + Ok(result) => return result, Err(_would_block) => { tracing::debug!("WouldBlock"); continue @@ -48,13 +39,4 @@ impl TunInterface { } } } - - #[instrument] - pub async fn try_recv(&mut self, buf: &mut [u8]) -> io::Result { - let mut guard = self.inner.readable_mut().await?; - match guard.try_io(|inner| (*inner).get_mut().recv(buf)) { - Ok(result) => Ok(result.unwrap_or_default()), - Err(_would_block) => Err(io::Error::new(io::ErrorKind::WouldBlock, "WouldBlock")), - } - } } diff --git a/tun/src/unix/mod.rs b/tun/src/unix/mod.rs index 775ba1d..77a1158 100644 --- a/tun/src/unix/mod.rs +++ b/tun/src/unix/mod.rs @@ -1,5 +1,6 @@ use std::{ io::{Error, Read}, + mem::MaybeUninit, os::fd::{AsRawFd, FromRawFd, IntoRawFd, RawFd}, }; @@ -38,14 +39,19 @@ impl IntoRawFd for TunInterface { } } +unsafe fn assume_init(buf: &[MaybeUninit]) -> &[u8] { + &*(buf as *const [MaybeUninit] as *const [u8]) +} + impl TunInterface { #[throws] #[instrument] - pub fn recv(&mut self, buf: &mut [u8]) -> usize { - // there might be a more efficient way to implement this - let tmp_buf = &mut [0u8; 1500]; - let len = self.socket.read(tmp_buf)?; - buf[..len - 4].copy_from_slice(&tmp_buf[4..len]); + pub fn recv(&self, buf: &mut [u8]) -> usize { + // Use IoVec to read directly into target buffer + let mut tmp_buf = [MaybeUninit::uninit(); 1500]; + let len = self.socket.recv(&mut tmp_buf)?; + let result_buf = unsafe { assume_init(&tmp_buf[4..len]) }; + buf[..len - 4].copy_from_slice(&result_buf); len - 4 }