Switch to RwLock

This commit is contained in:
Jett Chen 2023-11-27 12:19:14 +08:00
parent 02f5a4ef74
commit 6c1c806401
2 changed files with 49 additions and 50 deletions

View file

@ -8,11 +8,12 @@ use ip_network_table::IpNetworkTable;
use log::log;
use tokio::{
join,
sync::Mutex,
sync::{Mutex, RwLock},
task::{self, JoinHandle},
};
use tun::tokio::TunInterface;
use futures::future::join_all;
use futures::FutureExt;
use super::{noise::Tunnel, pcb, Peer, PeerPcb};
@ -34,7 +35,7 @@ impl PacketInterface for tun::tokio::TunInterface {
}
struct IndexedPcbs {
pcbs: Vec<Arc<Mutex<PeerPcb>>>,
pcbs: Vec<Arc<RwLock<PeerPcb>>>,
allowed_ips: IpNetworkTable<usize>,
}
@ -51,16 +52,16 @@ impl IndexedPcbs {
for allowed_ip in pcb.allowed_ips.iter() {
self.allowed_ips.insert(allowed_ip.clone(), idx);
}
self.pcbs.insert(idx, Arc::new(Mutex::new(pcb)));
self.pcbs.insert(idx, Arc::new(RwLock::new(pcb)));
}
pub fn find(&mut self, addr: IpAddr) -> Option<usize> {
pub fn find(&self, addr: IpAddr) -> Option<usize> {
let (_, &idx) = self.allowed_ips.longest_match(addr)?;
Some(idx)
}
pub async fn connect(&mut self, idx: usize, handle: JoinHandle<()>) {
self.pcbs[idx].lock().await.handle = Some(handle);
pub async fn connect(&self, idx: usize, handle: JoinHandle<()>) {
self.pcbs[idx].write().await.handle = Some(handle);
}
}
@ -74,20 +75,20 @@ impl FromIterator<PeerPcb> for IndexedPcbs {
}
pub struct Interface {
tun: Arc<Mutex<TunInterface>>,
pcbs: Arc<Mutex<IndexedPcbs>>,
tun: Arc<RwLock<TunInterface>>,
pcbs: Arc<IndexedPcbs>,
}
impl Interface {
#[throws]
pub fn new<I: IntoIterator<Item = Peer>>(tun: TunInterface, peers: I) -> Self {
let pcbs: IndexedPcbs = peers
let mut pcbs: IndexedPcbs = peers
.into_iter()
.map(|peer| PeerPcb::new(peer))
.collect::<Result<_, _>>()?;
let tun = Arc::new(Mutex::new(tun));
let pcbs = Arc::new(Mutex::new(pcbs));
let tun = Arc::new(RwLock::new(tun));
let pcbs = Arc::new(pcbs);
Self { tun, pcbs }
}
@ -101,19 +102,19 @@ impl Interface {
log::debug!("starting loop...");
let mut buf = [0u8; 3000];
let mut tun = tun.lock().await;
log::debug!("awaiting read...");
let src = match tun.recv(&mut buf[..]).await {
Ok(len) => &buf[..len],
Err(e) => {
log::error!("failed reading from interface: {}", e);
continue
}
let src = {
log::debug!("awaiting read...");
let src = match tun.write().await.recv(&mut buf[..]).await {
Ok(len) => &buf[..len],
Err(e) => {
log::error!("failed reading from interface: {}", e);
continue
}
};
log::debug!("read {} bytes from interface", src.len());
log::debug!("bytes: {:?}", src);
src
};
log::debug!("read {} bytes from interface", src.len());
log::debug!("bytes: {:?}", src);
let mut pcbs = pcbs.lock().await;
let dst_addr = match Tunnel::dst_address(src) {
Some(addr) => addr,
@ -131,7 +132,7 @@ impl Interface {
log::debug!("found peer:{}", idx);
match pcbs.pcbs[idx].lock().await.send(src).await {
match pcbs.pcbs[idx].read().await.send(src).await {
Ok(..) => {
log::debug!("sent packet to peer {}", dst_addr);
}
@ -152,12 +153,13 @@ impl Interface {
let outgoing = tokio::task::spawn(outgoing);
tsks.push(outgoing);
{
let pcbs = self.pcbs.lock().await;
let pcbs = self.pcbs;
for i in 0..pcbs.pcbs.len(){
let pcb = pcbs.pcbs[i].clone();
let mut pcb = pcbs.pcbs[i].clone();
let tun = tun.clone();
let tsk = async move{
pcb.lock().await.run(tun).await.unwrap();
let tsk = async move {
pcb.write().await.open_if_closed().await;
pcb.read().await.run(tun).await;
};
tsks.push(tokio::task::spawn(tsk));
}

View file

@ -1,12 +1,13 @@
use std::io;
use std::net::SocketAddr;
use std::rc::Rc;
use std::sync::Arc;
use anyhow::Error;
use anyhow::{anyhow, Error};
use fehler::throws;
use ip_network::IpNetwork;
use tokio::{net::UdpSocket, task::JoinHandle};
use tokio::sync::Mutex;
use tokio::sync::{Mutex, RwLock};
use super::{
iface::PacketInterface,
@ -20,14 +21,14 @@ pub struct PeerPcb {
pub allowed_ips: Vec<IpNetwork>,
pub handle: Option<JoinHandle<()>>,
socket: Option<UdpSocket>,
tunnel: Tunnel,
tunnel: RwLock<Tunnel>,
}
impl PeerPcb {
#[throws]
pub fn new(peer: Peer) -> Self {
let tunnel = Tunnel::new(peer.private_key, peer.public_key, peer.preshared_key, None, 1, None)
.map_err(|s| anyhow::anyhow!("{}", s))?;
let tunnel = RwLock::new(Tunnel::new(peer.private_key, peer.public_key, peer.preshared_key, None, 1, None)
.map_err(|s| anyhow::anyhow!("{}", s))?);
Self {
endpoint: peer.endpoint,
@ -38,7 +39,7 @@ impl PeerPcb {
}
}
async fn open_if_closed(&mut self) -> Result<(), Error> {
pub async fn open_if_closed(&mut self) -> Result<(), Error> {
if self.socket.is_none() {
let socket = UdpSocket::bind("0.0.0.0:0").await?;
socket.connect(self.endpoint).await?;
@ -47,23 +48,18 @@ impl PeerPcb {
Ok(())
}
pub async fn run(&self, interface: Arc<Mutex<impl PacketInterface>>) -> Result<(), Error> {
pub async fn run(&self, interface: Arc<RwLock<impl PacketInterface>>) -> Result<(), Error> {
let mut buf = [0u8; 3000];
log::debug!("starting read loop for pcb...");
loop {
let Some(socket) = self.socket.as_ref() else {
continue
tracing::debug!("looping");
let sock = match &self.socket {
None => {continue}
Some(sock) => {sock}
};
let packet = match socket.recv(&mut buf).await {
Ok(s) => &buf[..s],
Err(e) => {
tracing::error!("eror receiving on peer socket: {}", e);
continue
}
};
let (len, addr) = socket.recv_from(&mut buf).await?;
let (len, addr) = sock.recv_from(&mut buf).await?;
tracing::debug!("received {} bytes from {}", len, addr);
}
@ -80,7 +76,7 @@ impl PeerPcb {
tracing::debug!("Decapsulating {} bytes from {}", len, addr);
tracing::debug!("{:?}", &res_dat);
loop {
match self.tunnel.decapsulate(None, res_dat, &mut buf[..]) {
match self.tunnel.write().await.decapsulate(None, res_dat, &mut buf[..]) {
TunnResult::Done => {
tracing::debug!("Decapsulate done");
break;
@ -110,15 +106,16 @@ impl PeerPcb {
Ok(self.socket.as_ref().expect("socket was just opened"))
}
pub async fn send(&mut self, src: &[u8]) -> Result<(), Error> {
pub async fn send(&self, src: &[u8]) -> Result<(), Error> {
let mut dst_buf = [0u8; 3000];
match self.tunnel.encapsulate(src, &mut dst_buf[..]) {
match self.tunnel.write().await.encapsulate(src, &mut dst_buf[..]) {
TunnResult::Done => {}
TunnResult::Err(e) => {
tracing::error!(message = "Encapsulate error", error = ?e)
}
TunnResult::WriteToNetwork(packet) => {
let socket = self.socket().await?;
let socket = &self.socket.as_ref().ok_or(anyhow!("socket not open"))?;
tracing::debug!("Our Encapsulated packet: {:?}", packet);
socket.send(packet).await?;
}