Switch to RwLock
This commit is contained in:
parent
02f5a4ef74
commit
6c1c806401
2 changed files with 49 additions and 50 deletions
|
|
@ -8,11 +8,12 @@ use ip_network_table::IpNetworkTable;
|
|||
use log::log;
|
||||
use tokio::{
|
||||
join,
|
||||
sync::Mutex,
|
||||
sync::{Mutex, RwLock},
|
||||
task::{self, JoinHandle},
|
||||
};
|
||||
use tun::tokio::TunInterface;
|
||||
use futures::future::join_all;
|
||||
use futures::FutureExt;
|
||||
|
||||
use super::{noise::Tunnel, pcb, Peer, PeerPcb};
|
||||
|
||||
|
|
@ -34,7 +35,7 @@ impl PacketInterface for tun::tokio::TunInterface {
|
|||
}
|
||||
|
||||
struct IndexedPcbs {
|
||||
pcbs: Vec<Arc<Mutex<PeerPcb>>>,
|
||||
pcbs: Vec<Arc<RwLock<PeerPcb>>>,
|
||||
allowed_ips: IpNetworkTable<usize>,
|
||||
}
|
||||
|
||||
|
|
@ -51,16 +52,16 @@ impl IndexedPcbs {
|
|||
for allowed_ip in pcb.allowed_ips.iter() {
|
||||
self.allowed_ips.insert(allowed_ip.clone(), idx);
|
||||
}
|
||||
self.pcbs.insert(idx, Arc::new(Mutex::new(pcb)));
|
||||
self.pcbs.insert(idx, Arc::new(RwLock::new(pcb)));
|
||||
}
|
||||
|
||||
pub fn find(&mut self, addr: IpAddr) -> Option<usize> {
|
||||
pub fn find(&self, addr: IpAddr) -> Option<usize> {
|
||||
let (_, &idx) = self.allowed_ips.longest_match(addr)?;
|
||||
Some(idx)
|
||||
}
|
||||
|
||||
pub async fn connect(&mut self, idx: usize, handle: JoinHandle<()>) {
|
||||
self.pcbs[idx].lock().await.handle = Some(handle);
|
||||
pub async fn connect(&self, idx: usize, handle: JoinHandle<()>) {
|
||||
self.pcbs[idx].write().await.handle = Some(handle);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -74,20 +75,20 @@ impl FromIterator<PeerPcb> for IndexedPcbs {
|
|||
}
|
||||
|
||||
pub struct Interface {
|
||||
tun: Arc<Mutex<TunInterface>>,
|
||||
pcbs: Arc<Mutex<IndexedPcbs>>,
|
||||
tun: Arc<RwLock<TunInterface>>,
|
||||
pcbs: Arc<IndexedPcbs>,
|
||||
}
|
||||
|
||||
impl Interface {
|
||||
#[throws]
|
||||
pub fn new<I: IntoIterator<Item = Peer>>(tun: TunInterface, peers: I) -> Self {
|
||||
let pcbs: IndexedPcbs = peers
|
||||
let mut pcbs: IndexedPcbs = peers
|
||||
.into_iter()
|
||||
.map(|peer| PeerPcb::new(peer))
|
||||
.collect::<Result<_, _>>()?;
|
||||
|
||||
let tun = Arc::new(Mutex::new(tun));
|
||||
let pcbs = Arc::new(Mutex::new(pcbs));
|
||||
let tun = Arc::new(RwLock::new(tun));
|
||||
let pcbs = Arc::new(pcbs);
|
||||
Self { tun, pcbs }
|
||||
}
|
||||
|
||||
|
|
@ -101,9 +102,9 @@ impl Interface {
|
|||
log::debug!("starting loop...");
|
||||
let mut buf = [0u8; 3000];
|
||||
|
||||
let mut tun = tun.lock().await;
|
||||
let src = {
|
||||
log::debug!("awaiting read...");
|
||||
let src = match tun.recv(&mut buf[..]).await {
|
||||
let src = match tun.write().await.recv(&mut buf[..]).await {
|
||||
Ok(len) => &buf[..len],
|
||||
Err(e) => {
|
||||
log::error!("failed reading from interface: {}", e);
|
||||
|
|
@ -112,8 +113,8 @@ impl Interface {
|
|||
};
|
||||
log::debug!("read {} bytes from interface", src.len());
|
||||
log::debug!("bytes: {:?}", src);
|
||||
|
||||
let mut pcbs = pcbs.lock().await;
|
||||
src
|
||||
};
|
||||
|
||||
let dst_addr = match Tunnel::dst_address(src) {
|
||||
Some(addr) => addr,
|
||||
|
|
@ -131,7 +132,7 @@ impl Interface {
|
|||
|
||||
log::debug!("found peer:{}", idx);
|
||||
|
||||
match pcbs.pcbs[idx].lock().await.send(src).await {
|
||||
match pcbs.pcbs[idx].read().await.send(src).await {
|
||||
Ok(..) => {
|
||||
log::debug!("sent packet to peer {}", dst_addr);
|
||||
}
|
||||
|
|
@ -152,12 +153,13 @@ impl Interface {
|
|||
let outgoing = tokio::task::spawn(outgoing);
|
||||
tsks.push(outgoing);
|
||||
{
|
||||
let pcbs = self.pcbs.lock().await;
|
||||
let pcbs = self.pcbs;
|
||||
for i in 0..pcbs.pcbs.len(){
|
||||
let pcb = pcbs.pcbs[i].clone();
|
||||
let mut pcb = pcbs.pcbs[i].clone();
|
||||
let tun = tun.clone();
|
||||
let tsk = async move{
|
||||
pcb.lock().await.run(tun).await.unwrap();
|
||||
let tsk = async move {
|
||||
pcb.write().await.open_if_closed().await;
|
||||
pcb.read().await.run(tun).await;
|
||||
};
|
||||
tsks.push(tokio::task::spawn(tsk));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,12 +1,13 @@
|
|||
use std::io;
|
||||
use std::net::SocketAddr;
|
||||
use std::rc::Rc;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Error;
|
||||
use anyhow::{anyhow, Error};
|
||||
use fehler::throws;
|
||||
use ip_network::IpNetwork;
|
||||
use tokio::{net::UdpSocket, task::JoinHandle};
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::{Mutex, RwLock};
|
||||
|
||||
use super::{
|
||||
iface::PacketInterface,
|
||||
|
|
@ -20,14 +21,14 @@ pub struct PeerPcb {
|
|||
pub allowed_ips: Vec<IpNetwork>,
|
||||
pub handle: Option<JoinHandle<()>>,
|
||||
socket: Option<UdpSocket>,
|
||||
tunnel: Tunnel,
|
||||
tunnel: RwLock<Tunnel>,
|
||||
}
|
||||
|
||||
impl PeerPcb {
|
||||
#[throws]
|
||||
pub fn new(peer: Peer) -> Self {
|
||||
let tunnel = Tunnel::new(peer.private_key, peer.public_key, peer.preshared_key, None, 1, None)
|
||||
.map_err(|s| anyhow::anyhow!("{}", s))?;
|
||||
let tunnel = RwLock::new(Tunnel::new(peer.private_key, peer.public_key, peer.preshared_key, None, 1, None)
|
||||
.map_err(|s| anyhow::anyhow!("{}", s))?);
|
||||
|
||||
Self {
|
||||
endpoint: peer.endpoint,
|
||||
|
|
@ -38,7 +39,7 @@ impl PeerPcb {
|
|||
}
|
||||
}
|
||||
|
||||
async fn open_if_closed(&mut self) -> Result<(), Error> {
|
||||
pub async fn open_if_closed(&mut self) -> Result<(), Error> {
|
||||
if self.socket.is_none() {
|
||||
let socket = UdpSocket::bind("0.0.0.0:0").await?;
|
||||
socket.connect(self.endpoint).await?;
|
||||
|
|
@ -47,23 +48,18 @@ impl PeerPcb {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn run(&self, interface: Arc<Mutex<impl PacketInterface>>) -> Result<(), Error> {
|
||||
pub async fn run(&self, interface: Arc<RwLock<impl PacketInterface>>) -> Result<(), Error> {
|
||||
let mut buf = [0u8; 3000];
|
||||
log::debug!("starting read loop for pcb...");
|
||||
loop {
|
||||
let Some(socket) = self.socket.as_ref() else {
|
||||
continue
|
||||
tracing::debug!("looping");
|
||||
|
||||
let sock = match &self.socket {
|
||||
None => {continue}
|
||||
Some(sock) => {sock}
|
||||
};
|
||||
|
||||
let packet = match socket.recv(&mut buf).await {
|
||||
Ok(s) => &buf[..s],
|
||||
Err(e) => {
|
||||
tracing::error!("eror receiving on peer socket: {}", e);
|
||||
continue
|
||||
}
|
||||
};
|
||||
|
||||
let (len, addr) = socket.recv_from(&mut buf).await?;
|
||||
let (len, addr) = sock.recv_from(&mut buf).await?;
|
||||
|
||||
tracing::debug!("received {} bytes from {}", len, addr);
|
||||
}
|
||||
|
|
@ -80,7 +76,7 @@ impl PeerPcb {
|
|||
tracing::debug!("Decapsulating {} bytes from {}", len, addr);
|
||||
tracing::debug!("{:?}", &res_dat);
|
||||
loop {
|
||||
match self.tunnel.decapsulate(None, res_dat, &mut buf[..]) {
|
||||
match self.tunnel.write().await.decapsulate(None, res_dat, &mut buf[..]) {
|
||||
TunnResult::Done => {
|
||||
tracing::debug!("Decapsulate done");
|
||||
break;
|
||||
|
|
@ -110,15 +106,16 @@ impl PeerPcb {
|
|||
Ok(self.socket.as_ref().expect("socket was just opened"))
|
||||
}
|
||||
|
||||
pub async fn send(&mut self, src: &[u8]) -> Result<(), Error> {
|
||||
|
||||
pub async fn send(&self, src: &[u8]) -> Result<(), Error> {
|
||||
let mut dst_buf = [0u8; 3000];
|
||||
match self.tunnel.encapsulate(src, &mut dst_buf[..]) {
|
||||
match self.tunnel.write().await.encapsulate(src, &mut dst_buf[..]) {
|
||||
TunnResult::Done => {}
|
||||
TunnResult::Err(e) => {
|
||||
tracing::error!(message = "Encapsulate error", error = ?e)
|
||||
}
|
||||
TunnResult::WriteToNetwork(packet) => {
|
||||
let socket = self.socket().await?;
|
||||
let socket = &self.socket.as_ref().ok_or(anyhow!("socket not open"))?;
|
||||
tracing::debug!("Our Encapsulated packet: {:?}", packet);
|
||||
socket.send(packet).await?;
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue