Switch to RwLock

This commit is contained in:
Jett Chen 2023-11-27 12:19:14 +08:00
parent 02f5a4ef74
commit 6c1c806401
2 changed files with 49 additions and 50 deletions

View file

@ -8,11 +8,12 @@ use ip_network_table::IpNetworkTable;
use log::log; use log::log;
use tokio::{ use tokio::{
join, join,
sync::Mutex, sync::{Mutex, RwLock},
task::{self, JoinHandle}, task::{self, JoinHandle},
}; };
use tun::tokio::TunInterface; use tun::tokio::TunInterface;
use futures::future::join_all; use futures::future::join_all;
use futures::FutureExt;
use super::{noise::Tunnel, pcb, Peer, PeerPcb}; use super::{noise::Tunnel, pcb, Peer, PeerPcb};
@ -34,7 +35,7 @@ impl PacketInterface for tun::tokio::TunInterface {
} }
struct IndexedPcbs { struct IndexedPcbs {
pcbs: Vec<Arc<Mutex<PeerPcb>>>, pcbs: Vec<Arc<RwLock<PeerPcb>>>,
allowed_ips: IpNetworkTable<usize>, allowed_ips: IpNetworkTable<usize>,
} }
@ -51,16 +52,16 @@ impl IndexedPcbs {
for allowed_ip in pcb.allowed_ips.iter() { for allowed_ip in pcb.allowed_ips.iter() {
self.allowed_ips.insert(allowed_ip.clone(), idx); self.allowed_ips.insert(allowed_ip.clone(), idx);
} }
self.pcbs.insert(idx, Arc::new(Mutex::new(pcb))); self.pcbs.insert(idx, Arc::new(RwLock::new(pcb)));
} }
pub fn find(&mut self, addr: IpAddr) -> Option<usize> { pub fn find(&self, addr: IpAddr) -> Option<usize> {
let (_, &idx) = self.allowed_ips.longest_match(addr)?; let (_, &idx) = self.allowed_ips.longest_match(addr)?;
Some(idx) Some(idx)
} }
pub async fn connect(&mut self, idx: usize, handle: JoinHandle<()>) { pub async fn connect(&self, idx: usize, handle: JoinHandle<()>) {
self.pcbs[idx].lock().await.handle = Some(handle); self.pcbs[idx].write().await.handle = Some(handle);
} }
} }
@ -74,20 +75,20 @@ impl FromIterator<PeerPcb> for IndexedPcbs {
} }
pub struct Interface { pub struct Interface {
tun: Arc<Mutex<TunInterface>>, tun: Arc<RwLock<TunInterface>>,
pcbs: Arc<Mutex<IndexedPcbs>>, pcbs: Arc<IndexedPcbs>,
} }
impl Interface { impl Interface {
#[throws] #[throws]
pub fn new<I: IntoIterator<Item = Peer>>(tun: TunInterface, peers: I) -> Self { pub fn new<I: IntoIterator<Item = Peer>>(tun: TunInterface, peers: I) -> Self {
let pcbs: IndexedPcbs = peers let mut pcbs: IndexedPcbs = peers
.into_iter() .into_iter()
.map(|peer| PeerPcb::new(peer)) .map(|peer| PeerPcb::new(peer))
.collect::<Result<_, _>>()?; .collect::<Result<_, _>>()?;
let tun = Arc::new(Mutex::new(tun)); let tun = Arc::new(RwLock::new(tun));
let pcbs = Arc::new(Mutex::new(pcbs)); let pcbs = Arc::new(pcbs);
Self { tun, pcbs } Self { tun, pcbs }
} }
@ -101,9 +102,9 @@ impl Interface {
log::debug!("starting loop..."); log::debug!("starting loop...");
let mut buf = [0u8; 3000]; let mut buf = [0u8; 3000];
let mut tun = tun.lock().await; let src = {
log::debug!("awaiting read..."); log::debug!("awaiting read...");
let src = match tun.recv(&mut buf[..]).await { let src = match tun.write().await.recv(&mut buf[..]).await {
Ok(len) => &buf[..len], Ok(len) => &buf[..len],
Err(e) => { Err(e) => {
log::error!("failed reading from interface: {}", e); log::error!("failed reading from interface: {}", e);
@ -112,8 +113,8 @@ impl Interface {
}; };
log::debug!("read {} bytes from interface", src.len()); log::debug!("read {} bytes from interface", src.len());
log::debug!("bytes: {:?}", src); log::debug!("bytes: {:?}", src);
src
let mut pcbs = pcbs.lock().await; };
let dst_addr = match Tunnel::dst_address(src) { let dst_addr = match Tunnel::dst_address(src) {
Some(addr) => addr, Some(addr) => addr,
@ -131,7 +132,7 @@ impl Interface {
log::debug!("found peer:{}", idx); log::debug!("found peer:{}", idx);
match pcbs.pcbs[idx].lock().await.send(src).await { match pcbs.pcbs[idx].read().await.send(src).await {
Ok(..) => { Ok(..) => {
log::debug!("sent packet to peer {}", dst_addr); log::debug!("sent packet to peer {}", dst_addr);
} }
@ -152,12 +153,13 @@ impl Interface {
let outgoing = tokio::task::spawn(outgoing); let outgoing = tokio::task::spawn(outgoing);
tsks.push(outgoing); tsks.push(outgoing);
{ {
let pcbs = self.pcbs.lock().await; let pcbs = self.pcbs;
for i in 0..pcbs.pcbs.len(){ for i in 0..pcbs.pcbs.len(){
let pcb = pcbs.pcbs[i].clone(); let mut pcb = pcbs.pcbs[i].clone();
let tun = tun.clone(); let tun = tun.clone();
let tsk = async move { let tsk = async move {
pcb.lock().await.run(tun).await.unwrap(); pcb.write().await.open_if_closed().await;
pcb.read().await.run(tun).await;
}; };
tsks.push(tokio::task::spawn(tsk)); tsks.push(tokio::task::spawn(tsk));
} }

View file

@ -1,12 +1,13 @@
use std::io;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::rc::Rc; use std::rc::Rc;
use std::sync::Arc; use std::sync::Arc;
use anyhow::Error; use anyhow::{anyhow, Error};
use fehler::throws; use fehler::throws;
use ip_network::IpNetwork; use ip_network::IpNetwork;
use tokio::{net::UdpSocket, task::JoinHandle}; use tokio::{net::UdpSocket, task::JoinHandle};
use tokio::sync::Mutex; use tokio::sync::{Mutex, RwLock};
use super::{ use super::{
iface::PacketInterface, iface::PacketInterface,
@ -20,14 +21,14 @@ pub struct PeerPcb {
pub allowed_ips: Vec<IpNetwork>, pub allowed_ips: Vec<IpNetwork>,
pub handle: Option<JoinHandle<()>>, pub handle: Option<JoinHandle<()>>,
socket: Option<UdpSocket>, socket: Option<UdpSocket>,
tunnel: Tunnel, tunnel: RwLock<Tunnel>,
} }
impl PeerPcb { impl PeerPcb {
#[throws] #[throws]
pub fn new(peer: Peer) -> Self { pub fn new(peer: Peer) -> Self {
let tunnel = Tunnel::new(peer.private_key, peer.public_key, peer.preshared_key, None, 1, None) let tunnel = RwLock::new(Tunnel::new(peer.private_key, peer.public_key, peer.preshared_key, None, 1, None)
.map_err(|s| anyhow::anyhow!("{}", s))?; .map_err(|s| anyhow::anyhow!("{}", s))?);
Self { Self {
endpoint: peer.endpoint, endpoint: peer.endpoint,
@ -38,7 +39,7 @@ impl PeerPcb {
} }
} }
async fn open_if_closed(&mut self) -> Result<(), Error> { pub async fn open_if_closed(&mut self) -> Result<(), Error> {
if self.socket.is_none() { if self.socket.is_none() {
let socket = UdpSocket::bind("0.0.0.0:0").await?; let socket = UdpSocket::bind("0.0.0.0:0").await?;
socket.connect(self.endpoint).await?; socket.connect(self.endpoint).await?;
@ -47,23 +48,18 @@ impl PeerPcb {
Ok(()) Ok(())
} }
pub async fn run(&self, interface: Arc<Mutex<impl PacketInterface>>) -> Result<(), Error> { pub async fn run(&self, interface: Arc<RwLock<impl PacketInterface>>) -> Result<(), Error> {
let mut buf = [0u8; 3000]; let mut buf = [0u8; 3000];
log::debug!("starting read loop for pcb..."); log::debug!("starting read loop for pcb...");
loop { loop {
let Some(socket) = self.socket.as_ref() else { tracing::debug!("looping");
continue
let sock = match &self.socket {
None => {continue}
Some(sock) => {sock}
}; };
let packet = match socket.recv(&mut buf).await { let (len, addr) = sock.recv_from(&mut buf).await?;
Ok(s) => &buf[..s],
Err(e) => {
tracing::error!("eror receiving on peer socket: {}", e);
continue
}
};
let (len, addr) = socket.recv_from(&mut buf).await?;
tracing::debug!("received {} bytes from {}", len, addr); tracing::debug!("received {} bytes from {}", len, addr);
} }
@ -80,7 +76,7 @@ impl PeerPcb {
tracing::debug!("Decapsulating {} bytes from {}", len, addr); tracing::debug!("Decapsulating {} bytes from {}", len, addr);
tracing::debug!("{:?}", &res_dat); tracing::debug!("{:?}", &res_dat);
loop { loop {
match self.tunnel.decapsulate(None, res_dat, &mut buf[..]) { match self.tunnel.write().await.decapsulate(None, res_dat, &mut buf[..]) {
TunnResult::Done => { TunnResult::Done => {
tracing::debug!("Decapsulate done"); tracing::debug!("Decapsulate done");
break; break;
@ -110,15 +106,16 @@ impl PeerPcb {
Ok(self.socket.as_ref().expect("socket was just opened")) Ok(self.socket.as_ref().expect("socket was just opened"))
} }
pub async fn send(&mut self, src: &[u8]) -> Result<(), Error> {
pub async fn send(&self, src: &[u8]) -> Result<(), Error> {
let mut dst_buf = [0u8; 3000]; let mut dst_buf = [0u8; 3000];
match self.tunnel.encapsulate(src, &mut dst_buf[..]) { match self.tunnel.write().await.encapsulate(src, &mut dst_buf[..]) {
TunnResult::Done => {} TunnResult::Done => {}
TunnResult::Err(e) => { TunnResult::Err(e) => {
tracing::error!(message = "Encapsulate error", error = ?e) tracing::error!(message = "Encapsulate error", error = ?e)
} }
TunnResult::WriteToNetwork(packet) => { TunnResult::WriteToNetwork(packet) => {
let socket = self.socket().await?; let socket = &self.socket.as_ref().ok_or(anyhow!("socket not open"))?;
tracing::debug!("Our Encapsulated packet: {:?}", packet); tracing::debug!("Our Encapsulated packet: {:?}", packet);
socket.send(packet).await?; socket.send(packet).await?;
} }