From 3e5a01ffbe921a9209f5277222e7dd75cab4d2ac Mon Sep 17 00:00:00 2001 From: Conrad Kramer Date: Sat, 9 Dec 2023 19:47:41 -0800 Subject: [PATCH] Update locking to be interior to PeerPcb --- burrow/src/daemon/instance.rs | 7 +++-- burrow/src/daemon/mod.rs | 3 +- burrow/src/wireguard/iface.rs | 24 +++++++-------- burrow/src/wireguard/pcb.rs | 56 +++++++++++++++-------------------- tun/src/tokio/mod.rs | 28 ++++-------------- tun/src/unix/mod.rs | 16 ++++++---- 6 files changed, 57 insertions(+), 77 deletions(-) diff --git a/burrow/src/daemon/instance.rs b/burrow/src/daemon/instance.rs index 98052d2..6a430c5 100644 --- a/burrow/src/daemon/instance.rs +++ b/burrow/src/daemon/instance.rs @@ -54,11 +54,12 @@ impl DaemonInstance { warn!("Got start, but tun interface already up."); } RunState::Idle => { - debug!("Creating new TunInterface"); - let retrieved = TunInterface::new(tun::TunInterface::retrieve().unwrap())?; + let raw = tun::TunInterface::retrieve().unwrap(); + debug!("TunInterface retrieved: {:?}", raw.name()?); + + let retrieved = TunInterface::new(raw)?; let tun_if = Arc::new(RwLock::new(retrieved)); // let tun_if = Arc::new(RwLock::new(TunInterface::new(st.tun.open()?)?)); - debug!("TunInterface created: {:?}", tun_if); debug!("Setting tun_interface"); self.tun_interface = Some(tun_if.clone()); diff --git a/burrow/src/daemon/mod.rs b/burrow/src/daemon/mod.rs index 394ebec..1020cf7 100644 --- a/burrow/src/daemon/mod.rs +++ b/burrow/src/daemon/mod.rs @@ -62,7 +62,8 @@ pub async fn daemon_main() -> Result<()> { allowed_ips: vec![IpNetwork::V4(Ipv4Network::DEFAULT_ROUTE)], }])?; - let mut inst = DaemonInstance::new(commands_rx, response_tx, Arc::new(RwLock::new(iface))); + let mut inst: DaemonInstance = + DaemonInstance::new(commands_rx, response_tx, Arc::new(RwLock::new(iface))); tracing::info!("Starting daemon jobs..."); diff --git a/burrow/src/wireguard/iface.rs b/burrow/src/wireguard/iface.rs index 9a0c216..3d1823b 100755 --- a/burrow/src/wireguard/iface.rs +++ b/burrow/src/wireguard/iface.rs @@ -29,7 +29,7 @@ impl PacketInterface for tun::tokio::TunInterface { } struct IndexedPcbs { - pcbs: Vec>>, + pcbs: Vec>, allowed_ips: IpNetworkTable, } @@ -46,7 +46,7 @@ impl IndexedPcbs { for allowed_ip in pcb.allowed_ips.iter() { self.allowed_ips.insert(allowed_ip.clone(), idx); } - self.pcbs.insert(idx, Arc::new(RwLock::new(pcb))); + self.pcbs.insert(idx, Arc::new(pcb)); } pub fn find(&self, addr: IpAddr) -> Option { @@ -55,7 +55,7 @@ impl IndexedPcbs { } pub async fn connect(&self, idx: usize, handle: JoinHandle<()>) { - self.pcbs[idx].write().await.handle = Some(handle); + self.pcbs[idx].handle.write().await.replace(handle); } } @@ -106,7 +106,7 @@ impl Interface { let src = { let src = match timeout( Duration::from_millis(10), - tun.write().await.recv(&mut buf[..]), + tun.read().await.recv(&mut buf[..]), ) .await { @@ -138,9 +138,10 @@ impl Interface { tracing::debug!("found peer:{}", idx); - match pcbs.pcbs[idx].read().await.send(src).await { + match pcbs.pcbs[idx].send(src).await { Ok(..) => { - tracing::debug!("sent packet to peer {}", dst_addr); + let addr = pcbs.pcbs[idx].endpoint; + tracing::debug!("sent packet to peer {}", addr); } Err(e) => { log::error!("failed to send packet {}", e); @@ -166,14 +167,11 @@ impl Interface { let pcb = pcbs.pcbs[i].clone(); let tun = tun.clone(); let tsk = async move { - { - let r1 = pcb.write().await.open_if_closed().await; - if let Err(e) = r1 { - log::error!("failed to open pcb: {}", e); - return - } + if let Err(e) = pcb.open_if_closed().await { + log::error!("failed to open pcb: {}", e); + return } - let r2 = pcb.read().await.run(tun).await; + let r2 = pcb.run(tun).await; if let Err(e) = r2 { log::error!("failed to run pcb: {}", e); return diff --git a/burrow/src/wireguard/pcb.rs b/burrow/src/wireguard/pcb.rs index d11e736..21b1d6e 100755 --- a/burrow/src/wireguard/pcb.rs +++ b/burrow/src/wireguard/pcb.rs @@ -1,4 +1,8 @@ -use std::{net::SocketAddr, sync::Arc}; +use std::{ + cell::{Cell, RefCell}, + net::SocketAddr, + sync::Arc, +}; use anyhow::{anyhow, Error}; use fehler::throws; @@ -8,7 +12,6 @@ use tokio::{net::UdpSocket, sync::RwLock, task::JoinHandle}; use tun::tokio::TunInterface; use super::{ - iface::PacketInterface, noise::{TunnResult, Tunnel}, Peer, }; @@ -17,8 +20,8 @@ use super::{ pub struct PeerPcb { pub endpoint: SocketAddr, pub allowed_ips: Vec, - pub handle: Option>, - socket: Option, + pub handle: RwLock>>, + socket: RwLock>, tunnel: RwLock, } @@ -36,46 +39,35 @@ impl PeerPcb { ) .map_err(|s| anyhow::anyhow!("{}", s))?, ); - Self { endpoint: peer.endpoint, allowed_ips: peer.allowed_ips, - handle: None, - socket: None, + handle: RwLock::new(None), + socket: RwLock::new(None), tunnel, } } - pub async fn open_if_closed(&mut self) -> Result<(), Error> { - if self.socket.is_none() { + pub async fn open_if_closed(&self) -> Result<(), Error> { + if self.socket.read().await.is_none() { let socket = UdpSocket::bind("0.0.0.0:0").await?; socket.connect(self.endpoint).await?; - self.socket = Some(socket); + self.socket.write().await.replace(socket); } Ok(()) } pub async fn run(&self, tun_interface: Arc>) -> Result<(), Error> { - let mut buf = [0u8; 3000]; - tracing::debug!("starting read loop for pcb..."); - loop { - tracing::debug!("waiting for packet"); - let len = self.recv(&mut buf, tun_interface.clone()).await?; - tracing::debug!("received {} bytes", len); - } - } - - pub async fn recv( - &self, - buf: &mut [u8], - tun_interface: Arc>, - ) -> Result { tracing::debug!("starting read loop for pcb... for {:?}", &self); let rid: i32 = random(); + let mut buf: [u8; 3000] = [0u8; 3000]; tracing::debug!("start read loop {}", rid); loop { tracing::debug!("{}: waiting for packet", rid); - let Some(socket) = &self.socket else { continue }; + let guard = self.socket.read().await; + let Some(socket) = guard.as_ref() else { + continue + }; let mut res_buf = [0; 1500]; // tracing::debug!("{} : waiting for readability on {:?}", rid, socket); let len = match socket.recv(&mut res_buf).await { @@ -102,6 +94,7 @@ impl PeerPcb { } TunnResult::WriteToNetwork(packet) => { tracing::debug!("WriteToNetwork: {:?}", packet); + self.open_if_closed().await?; socket.send(packet).await?; tracing::debug!("WriteToNetwork done"); res_dat = &[]; @@ -119,15 +112,9 @@ impl PeerPcb { } } } - return Ok(len) } } - pub async fn socket(&mut self) -> Result<&UdpSocket, Error> { - self.open_if_closed().await?; - Ok(self.socket.as_ref().expect("socket was just opened")) - } - pub async fn send(&self, src: &[u8]) -> Result<(), Error> { let mut dst_buf = [0u8; 3000]; match self.tunnel.write().await.encapsulate(src, &mut dst_buf[..]) { @@ -136,7 +123,12 @@ impl PeerPcb { tracing::error!(message = "Encapsulate error", error = ?e) } TunnResult::WriteToNetwork(packet) => { - let socket = &self.socket.as_ref().ok_or(anyhow!("socket not open"))?; + self.open_if_closed().await?; + let handle = self.socket.read().await; + let Some(socket) = handle.as_ref() else { + tracing::error!("No socket for peer"); + return Ok(()) + }; tracing::debug!("Our Encapsulated packet: {:?}", packet); socket.send(packet).await?; } diff --git a/tun/src/tokio/mod.rs b/tun/src/tokio/mod.rs index 525e4d7..947fb74 100644 --- a/tun/src/tokio/mod.rs +++ b/tun/src/tokio/mod.rs @@ -26,21 +26,12 @@ impl TunInterface { } } - // #[instrument] - pub async fn recv(&mut self, buf: &mut [u8]) -> io::Result { + #[instrument] + pub async fn recv(&self, buf: &mut [u8]) -> io::Result { loop { - // tracing::debug!("TunInterface receiving..."); - let mut guard = self.inner.readable_mut().await?; - // tracing::debug!("Got! readable_mut"); - match guard.try_io(|inner| { - let raw_ref = (*inner).get_mut(); - let recved = raw_ref.recv(buf); - recved - }) { - Ok(result) => { - tracing::debug!("HORRAY"); - return result - } + let mut guard = self.inner.readable().await?; + match guard.try_io(|inner| inner.get_ref().recv(buf)) { + Ok(result) => return result, Err(_would_block) => { tracing::debug!("WouldBlock"); continue @@ -48,13 +39,4 @@ impl TunInterface { } } } - - #[instrument] - pub async fn try_recv(&mut self, buf: &mut [u8]) -> io::Result { - let mut guard = self.inner.readable_mut().await?; - match guard.try_io(|inner| (*inner).get_mut().recv(buf)) { - Ok(result) => Ok(result.unwrap_or_default()), - Err(_would_block) => Err(io::Error::new(io::ErrorKind::WouldBlock, "WouldBlock")), - } - } } diff --git a/tun/src/unix/mod.rs b/tun/src/unix/mod.rs index 775ba1d..77a1158 100644 --- a/tun/src/unix/mod.rs +++ b/tun/src/unix/mod.rs @@ -1,5 +1,6 @@ use std::{ io::{Error, Read}, + mem::MaybeUninit, os::fd::{AsRawFd, FromRawFd, IntoRawFd, RawFd}, }; @@ -38,14 +39,19 @@ impl IntoRawFd for TunInterface { } } +unsafe fn assume_init(buf: &[MaybeUninit]) -> &[u8] { + &*(buf as *const [MaybeUninit] as *const [u8]) +} + impl TunInterface { #[throws] #[instrument] - pub fn recv(&mut self, buf: &mut [u8]) -> usize { - // there might be a more efficient way to implement this - let tmp_buf = &mut [0u8; 1500]; - let len = self.socket.read(tmp_buf)?; - buf[..len - 4].copy_from_slice(&tmp_buf[4..len]); + pub fn recv(&self, buf: &mut [u8]) -> usize { + // Use IoVec to read directly into target buffer + let mut tmp_buf = [MaybeUninit::uninit(); 1500]; + let len = self.socket.recv(&mut tmp_buf)?; + let result_buf = unsafe { assume_init(&tmp_buf[4..len]) }; + buf[..len - 4].copy_from_slice(&result_buf); len - 4 }