From 6c1c806401e105df30cd96668939c1992f5710d5 Mon Sep 17 00:00:00 2001 From: Jett Chen Date: Mon, 27 Nov 2023 12:19:14 +0800 Subject: [PATCH] Switch to RwLock --- burrow/src/wireguard/iface.rs | 58 ++++++++++++++++++----------------- burrow/src/wireguard/pcb.rs | 41 ++++++++++++------------- 2 files changed, 49 insertions(+), 50 deletions(-) diff --git a/burrow/src/wireguard/iface.rs b/burrow/src/wireguard/iface.rs index 2373cc3..7e260d6 100755 --- a/burrow/src/wireguard/iface.rs +++ b/burrow/src/wireguard/iface.rs @@ -8,11 +8,12 @@ use ip_network_table::IpNetworkTable; use log::log; use tokio::{ join, - sync::Mutex, + sync::{Mutex, RwLock}, task::{self, JoinHandle}, }; use tun::tokio::TunInterface; use futures::future::join_all; +use futures::FutureExt; use super::{noise::Tunnel, pcb, Peer, PeerPcb}; @@ -34,7 +35,7 @@ impl PacketInterface for tun::tokio::TunInterface { } struct IndexedPcbs { - pcbs: Vec>>, + pcbs: Vec>>, allowed_ips: IpNetworkTable, } @@ -51,16 +52,16 @@ impl IndexedPcbs { for allowed_ip in pcb.allowed_ips.iter() { self.allowed_ips.insert(allowed_ip.clone(), idx); } - self.pcbs.insert(idx, Arc::new(Mutex::new(pcb))); + self.pcbs.insert(idx, Arc::new(RwLock::new(pcb))); } - pub fn find(&mut self, addr: IpAddr) -> Option { + pub fn find(&self, addr: IpAddr) -> Option { let (_, &idx) = self.allowed_ips.longest_match(addr)?; Some(idx) } - pub async fn connect(&mut self, idx: usize, handle: JoinHandle<()>) { - self.pcbs[idx].lock().await.handle = Some(handle); + pub async fn connect(&self, idx: usize, handle: JoinHandle<()>) { + self.pcbs[idx].write().await.handle = Some(handle); } } @@ -74,20 +75,20 @@ impl FromIterator for IndexedPcbs { } pub struct Interface { - tun: Arc>, - pcbs: Arc>, + tun: Arc>, + pcbs: Arc, } impl Interface { #[throws] pub fn new>(tun: TunInterface, peers: I) -> Self { - let pcbs: IndexedPcbs = peers + let mut pcbs: IndexedPcbs = peers .into_iter() .map(|peer| PeerPcb::new(peer)) .collect::>()?; - let tun = Arc::new(Mutex::new(tun)); - let pcbs = Arc::new(Mutex::new(pcbs)); + let tun = Arc::new(RwLock::new(tun)); + let pcbs = Arc::new(pcbs); Self { tun, pcbs } } @@ -101,19 +102,19 @@ impl Interface { log::debug!("starting loop..."); let mut buf = [0u8; 3000]; - let mut tun = tun.lock().await; - log::debug!("awaiting read..."); - let src = match tun.recv(&mut buf[..]).await { - Ok(len) => &buf[..len], - Err(e) => { - log::error!("failed reading from interface: {}", e); - continue - } + let src = { + log::debug!("awaiting read..."); + let src = match tun.write().await.recv(&mut buf[..]).await { + Ok(len) => &buf[..len], + Err(e) => { + log::error!("failed reading from interface: {}", e); + continue + } + }; + log::debug!("read {} bytes from interface", src.len()); + log::debug!("bytes: {:?}", src); + src }; - log::debug!("read {} bytes from interface", src.len()); - log::debug!("bytes: {:?}", src); - - let mut pcbs = pcbs.lock().await; let dst_addr = match Tunnel::dst_address(src) { Some(addr) => addr, @@ -131,7 +132,7 @@ impl Interface { log::debug!("found peer:{}", idx); - match pcbs.pcbs[idx].lock().await.send(src).await { + match pcbs.pcbs[idx].read().await.send(src).await { Ok(..) => { log::debug!("sent packet to peer {}", dst_addr); } @@ -152,12 +153,13 @@ impl Interface { let outgoing = tokio::task::spawn(outgoing); tsks.push(outgoing); { - let pcbs = self.pcbs.lock().await; + let pcbs = self.pcbs; for i in 0..pcbs.pcbs.len(){ - let pcb = pcbs.pcbs[i].clone(); + let mut pcb = pcbs.pcbs[i].clone(); let tun = tun.clone(); - let tsk = async move{ - pcb.lock().await.run(tun).await.unwrap(); + let tsk = async move { + pcb.write().await.open_if_closed().await; + pcb.read().await.run(tun).await; }; tsks.push(tokio::task::spawn(tsk)); } diff --git a/burrow/src/wireguard/pcb.rs b/burrow/src/wireguard/pcb.rs index 151aaf8..051ca53 100755 --- a/burrow/src/wireguard/pcb.rs +++ b/burrow/src/wireguard/pcb.rs @@ -1,12 +1,13 @@ +use std::io; use std::net::SocketAddr; use std::rc::Rc; use std::sync::Arc; -use anyhow::Error; +use anyhow::{anyhow, Error}; use fehler::throws; use ip_network::IpNetwork; use tokio::{net::UdpSocket, task::JoinHandle}; -use tokio::sync::Mutex; +use tokio::sync::{Mutex, RwLock}; use super::{ iface::PacketInterface, @@ -20,14 +21,14 @@ pub struct PeerPcb { pub allowed_ips: Vec, pub handle: Option>, socket: Option, - tunnel: Tunnel, + tunnel: RwLock, } impl PeerPcb { #[throws] pub fn new(peer: Peer) -> Self { - let tunnel = Tunnel::new(peer.private_key, peer.public_key, peer.preshared_key, None, 1, None) - .map_err(|s| anyhow::anyhow!("{}", s))?; + let tunnel = RwLock::new(Tunnel::new(peer.private_key, peer.public_key, peer.preshared_key, None, 1, None) + .map_err(|s| anyhow::anyhow!("{}", s))?); Self { endpoint: peer.endpoint, @@ -38,7 +39,7 @@ impl PeerPcb { } } - async fn open_if_closed(&mut self) -> Result<(), Error> { + pub async fn open_if_closed(&mut self) -> Result<(), Error> { if self.socket.is_none() { let socket = UdpSocket::bind("0.0.0.0:0").await?; socket.connect(self.endpoint).await?; @@ -47,23 +48,18 @@ impl PeerPcb { Ok(()) } - pub async fn run(&self, interface: Arc>) -> Result<(), Error> { + pub async fn run(&self, interface: Arc>) -> Result<(), Error> { let mut buf = [0u8; 3000]; log::debug!("starting read loop for pcb..."); loop { - let Some(socket) = self.socket.as_ref() else { - continue + tracing::debug!("looping"); + + let sock = match &self.socket { + None => {continue} + Some(sock) => {sock} }; - let packet = match socket.recv(&mut buf).await { - Ok(s) => &buf[..s], - Err(e) => { - tracing::error!("eror receiving on peer socket: {}", e); - continue - } - }; - - let (len, addr) = socket.recv_from(&mut buf).await?; + let (len, addr) = sock.recv_from(&mut buf).await?; tracing::debug!("received {} bytes from {}", len, addr); } @@ -80,7 +76,7 @@ impl PeerPcb { tracing::debug!("Decapsulating {} bytes from {}", len, addr); tracing::debug!("{:?}", &res_dat); loop { - match self.tunnel.decapsulate(None, res_dat, &mut buf[..]) { + match self.tunnel.write().await.decapsulate(None, res_dat, &mut buf[..]) { TunnResult::Done => { tracing::debug!("Decapsulate done"); break; @@ -110,15 +106,16 @@ impl PeerPcb { Ok(self.socket.as_ref().expect("socket was just opened")) } - pub async fn send(&mut self, src: &[u8]) -> Result<(), Error> { + + pub async fn send(&self, src: &[u8]) -> Result<(), Error> { let mut dst_buf = [0u8; 3000]; - match self.tunnel.encapsulate(src, &mut dst_buf[..]) { + match self.tunnel.write().await.encapsulate(src, &mut dst_buf[..]) { TunnResult::Done => {} TunnResult::Err(e) => { tracing::error!(message = "Encapsulate error", error = ?e) } TunnResult::WriteToNetwork(packet) => { - let socket = self.socket().await?; + let socket = &self.socket.as_ref().ok_or(anyhow!("socket not open"))?; tracing::debug!("Our Encapsulated packet: {:?}", packet); socket.send(packet).await?; }