Switch to RwLock
This commit is contained in:
parent
c7c4e5779c
commit
73b3136597
2 changed files with 49 additions and 50 deletions
|
|
@ -8,11 +8,12 @@ use ip_network_table::IpNetworkTable;
|
||||||
use log::log;
|
use log::log;
|
||||||
use tokio::{
|
use tokio::{
|
||||||
join,
|
join,
|
||||||
sync::Mutex,
|
sync::{Mutex, RwLock},
|
||||||
task::{self, JoinHandle},
|
task::{self, JoinHandle},
|
||||||
};
|
};
|
||||||
use tun::tokio::TunInterface;
|
use tun::tokio::TunInterface;
|
||||||
use futures::future::join_all;
|
use futures::future::join_all;
|
||||||
|
use futures::FutureExt;
|
||||||
|
|
||||||
use super::{noise::Tunnel, pcb, Peer, PeerPcb};
|
use super::{noise::Tunnel, pcb, Peer, PeerPcb};
|
||||||
|
|
||||||
|
|
@ -34,7 +35,7 @@ impl PacketInterface for tun::tokio::TunInterface {
|
||||||
}
|
}
|
||||||
|
|
||||||
struct IndexedPcbs {
|
struct IndexedPcbs {
|
||||||
pcbs: Vec<Arc<Mutex<PeerPcb>>>,
|
pcbs: Vec<Arc<RwLock<PeerPcb>>>,
|
||||||
allowed_ips: IpNetworkTable<usize>,
|
allowed_ips: IpNetworkTable<usize>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -51,16 +52,16 @@ impl IndexedPcbs {
|
||||||
for allowed_ip in pcb.allowed_ips.iter() {
|
for allowed_ip in pcb.allowed_ips.iter() {
|
||||||
self.allowed_ips.insert(allowed_ip.clone(), idx);
|
self.allowed_ips.insert(allowed_ip.clone(), idx);
|
||||||
}
|
}
|
||||||
self.pcbs.insert(idx, Arc::new(Mutex::new(pcb)));
|
self.pcbs.insert(idx, Arc::new(RwLock::new(pcb)));
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn find(&mut self, addr: IpAddr) -> Option<usize> {
|
pub fn find(&self, addr: IpAddr) -> Option<usize> {
|
||||||
let (_, &idx) = self.allowed_ips.longest_match(addr)?;
|
let (_, &idx) = self.allowed_ips.longest_match(addr)?;
|
||||||
Some(idx)
|
Some(idx)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn connect(&mut self, idx: usize, handle: JoinHandle<()>) {
|
pub async fn connect(&self, idx: usize, handle: JoinHandle<()>) {
|
||||||
self.pcbs[idx].lock().await.handle = Some(handle);
|
self.pcbs[idx].write().await.handle = Some(handle);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -74,20 +75,20 @@ impl FromIterator<PeerPcb> for IndexedPcbs {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct Interface {
|
pub struct Interface {
|
||||||
tun: Arc<Mutex<TunInterface>>,
|
tun: Arc<RwLock<TunInterface>>,
|
||||||
pcbs: Arc<Mutex<IndexedPcbs>>,
|
pcbs: Arc<IndexedPcbs>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Interface {
|
impl Interface {
|
||||||
#[throws]
|
#[throws]
|
||||||
pub fn new<I: IntoIterator<Item = Peer>>(tun: TunInterface, peers: I) -> Self {
|
pub fn new<I: IntoIterator<Item = Peer>>(tun: TunInterface, peers: I) -> Self {
|
||||||
let pcbs: IndexedPcbs = peers
|
let mut pcbs: IndexedPcbs = peers
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|peer| PeerPcb::new(peer))
|
.map(|peer| PeerPcb::new(peer))
|
||||||
.collect::<Result<_, _>>()?;
|
.collect::<Result<_, _>>()?;
|
||||||
|
|
||||||
let tun = Arc::new(Mutex::new(tun));
|
let tun = Arc::new(RwLock::new(tun));
|
||||||
let pcbs = Arc::new(Mutex::new(pcbs));
|
let pcbs = Arc::new(pcbs);
|
||||||
Self { tun, pcbs }
|
Self { tun, pcbs }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -101,9 +102,9 @@ impl Interface {
|
||||||
log::debug!("starting loop...");
|
log::debug!("starting loop...");
|
||||||
let mut buf = [0u8; 3000];
|
let mut buf = [0u8; 3000];
|
||||||
|
|
||||||
let mut tun = tun.lock().await;
|
let src = {
|
||||||
log::debug!("awaiting read...");
|
log::debug!("awaiting read...");
|
||||||
let src = match tun.recv(&mut buf[..]).await {
|
let src = match tun.write().await.recv(&mut buf[..]).await {
|
||||||
Ok(len) => &buf[..len],
|
Ok(len) => &buf[..len],
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
log::error!("failed reading from interface: {}", e);
|
log::error!("failed reading from interface: {}", e);
|
||||||
|
|
@ -112,8 +113,8 @@ impl Interface {
|
||||||
};
|
};
|
||||||
log::debug!("read {} bytes from interface", src.len());
|
log::debug!("read {} bytes from interface", src.len());
|
||||||
log::debug!("bytes: {:?}", src);
|
log::debug!("bytes: {:?}", src);
|
||||||
|
src
|
||||||
let mut pcbs = pcbs.lock().await;
|
};
|
||||||
|
|
||||||
let dst_addr = match Tunnel::dst_address(src) {
|
let dst_addr = match Tunnel::dst_address(src) {
|
||||||
Some(addr) => addr,
|
Some(addr) => addr,
|
||||||
|
|
@ -131,7 +132,7 @@ impl Interface {
|
||||||
|
|
||||||
log::debug!("found peer:{}", idx);
|
log::debug!("found peer:{}", idx);
|
||||||
|
|
||||||
match pcbs.pcbs[idx].lock().await.send(src).await {
|
match pcbs.pcbs[idx].read().await.send(src).await {
|
||||||
Ok(..) => {
|
Ok(..) => {
|
||||||
log::debug!("sent packet to peer {}", dst_addr);
|
log::debug!("sent packet to peer {}", dst_addr);
|
||||||
}
|
}
|
||||||
|
|
@ -152,12 +153,13 @@ impl Interface {
|
||||||
let outgoing = tokio::task::spawn(outgoing);
|
let outgoing = tokio::task::spawn(outgoing);
|
||||||
tsks.push(outgoing);
|
tsks.push(outgoing);
|
||||||
{
|
{
|
||||||
let pcbs = self.pcbs.lock().await;
|
let pcbs = self.pcbs;
|
||||||
for i in 0..pcbs.pcbs.len(){
|
for i in 0..pcbs.pcbs.len(){
|
||||||
let pcb = pcbs.pcbs[i].clone();
|
let mut pcb = pcbs.pcbs[i].clone();
|
||||||
let tun = tun.clone();
|
let tun = tun.clone();
|
||||||
let tsk = async move {
|
let tsk = async move {
|
||||||
pcb.lock().await.run(tun).await.unwrap();
|
pcb.write().await.open_if_closed().await;
|
||||||
|
pcb.read().await.run(tun).await;
|
||||||
};
|
};
|
||||||
tsks.push(tokio::task::spawn(tsk));
|
tsks.push(tokio::task::spawn(tsk));
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,13 @@
|
||||||
|
use std::io;
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use std::rc::Rc;
|
use std::rc::Rc;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use anyhow::Error;
|
use anyhow::{anyhow, Error};
|
||||||
use fehler::throws;
|
use fehler::throws;
|
||||||
use ip_network::IpNetwork;
|
use ip_network::IpNetwork;
|
||||||
use tokio::{net::UdpSocket, task::JoinHandle};
|
use tokio::{net::UdpSocket, task::JoinHandle};
|
||||||
use tokio::sync::Mutex;
|
use tokio::sync::{Mutex, RwLock};
|
||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
iface::PacketInterface,
|
iface::PacketInterface,
|
||||||
|
|
@ -20,14 +21,14 @@ pub struct PeerPcb {
|
||||||
pub allowed_ips: Vec<IpNetwork>,
|
pub allowed_ips: Vec<IpNetwork>,
|
||||||
pub handle: Option<JoinHandle<()>>,
|
pub handle: Option<JoinHandle<()>>,
|
||||||
socket: Option<UdpSocket>,
|
socket: Option<UdpSocket>,
|
||||||
tunnel: Tunnel,
|
tunnel: RwLock<Tunnel>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl PeerPcb {
|
impl PeerPcb {
|
||||||
#[throws]
|
#[throws]
|
||||||
pub fn new(peer: Peer) -> Self {
|
pub fn new(peer: Peer) -> Self {
|
||||||
let tunnel = Tunnel::new(peer.private_key, peer.public_key, peer.preshared_key, None, 1, None)
|
let tunnel = RwLock::new(Tunnel::new(peer.private_key, peer.public_key, peer.preshared_key, None, 1, None)
|
||||||
.map_err(|s| anyhow::anyhow!("{}", s))?;
|
.map_err(|s| anyhow::anyhow!("{}", s))?);
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
endpoint: peer.endpoint,
|
endpoint: peer.endpoint,
|
||||||
|
|
@ -38,7 +39,7 @@ impl PeerPcb {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn open_if_closed(&mut self) -> Result<(), Error> {
|
pub async fn open_if_closed(&mut self) -> Result<(), Error> {
|
||||||
if self.socket.is_none() {
|
if self.socket.is_none() {
|
||||||
let socket = UdpSocket::bind("0.0.0.0:0").await?;
|
let socket = UdpSocket::bind("0.0.0.0:0").await?;
|
||||||
socket.connect(self.endpoint).await?;
|
socket.connect(self.endpoint).await?;
|
||||||
|
|
@ -47,23 +48,18 @@ impl PeerPcb {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn run(&self, interface: Arc<Mutex<impl PacketInterface>>) -> Result<(), Error> {
|
pub async fn run(&self, interface: Arc<RwLock<impl PacketInterface>>) -> Result<(), Error> {
|
||||||
let mut buf = [0u8; 3000];
|
let mut buf = [0u8; 3000];
|
||||||
log::debug!("starting read loop for pcb...");
|
log::debug!("starting read loop for pcb...");
|
||||||
loop {
|
loop {
|
||||||
let Some(socket) = self.socket.as_ref() else {
|
tracing::debug!("looping");
|
||||||
continue
|
|
||||||
|
let sock = match &self.socket {
|
||||||
|
None => {continue}
|
||||||
|
Some(sock) => {sock}
|
||||||
};
|
};
|
||||||
|
|
||||||
let packet = match socket.recv(&mut buf).await {
|
let (len, addr) = sock.recv_from(&mut buf).await?;
|
||||||
Ok(s) => &buf[..s],
|
|
||||||
Err(e) => {
|
|
||||||
tracing::error!("eror receiving on peer socket: {}", e);
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let (len, addr) = socket.recv_from(&mut buf).await?;
|
|
||||||
|
|
||||||
tracing::debug!("received {} bytes from {}", len, addr);
|
tracing::debug!("received {} bytes from {}", len, addr);
|
||||||
}
|
}
|
||||||
|
|
@ -80,7 +76,7 @@ impl PeerPcb {
|
||||||
tracing::debug!("Decapsulating {} bytes from {}", len, addr);
|
tracing::debug!("Decapsulating {} bytes from {}", len, addr);
|
||||||
tracing::debug!("{:?}", &res_dat);
|
tracing::debug!("{:?}", &res_dat);
|
||||||
loop {
|
loop {
|
||||||
match self.tunnel.decapsulate(None, res_dat, &mut buf[..]) {
|
match self.tunnel.write().await.decapsulate(None, res_dat, &mut buf[..]) {
|
||||||
TunnResult::Done => {
|
TunnResult::Done => {
|
||||||
tracing::debug!("Decapsulate done");
|
tracing::debug!("Decapsulate done");
|
||||||
break;
|
break;
|
||||||
|
|
@ -110,15 +106,16 @@ impl PeerPcb {
|
||||||
Ok(self.socket.as_ref().expect("socket was just opened"))
|
Ok(self.socket.as_ref().expect("socket was just opened"))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn send(&mut self, src: &[u8]) -> Result<(), Error> {
|
|
||||||
|
pub async fn send(&self, src: &[u8]) -> Result<(), Error> {
|
||||||
let mut dst_buf = [0u8; 3000];
|
let mut dst_buf = [0u8; 3000];
|
||||||
match self.tunnel.encapsulate(src, &mut dst_buf[..]) {
|
match self.tunnel.write().await.encapsulate(src, &mut dst_buf[..]) {
|
||||||
TunnResult::Done => {}
|
TunnResult::Done => {}
|
||||||
TunnResult::Err(e) => {
|
TunnResult::Err(e) => {
|
||||||
tracing::error!(message = "Encapsulate error", error = ?e)
|
tracing::error!(message = "Encapsulate error", error = ?e)
|
||||||
}
|
}
|
||||||
TunnResult::WriteToNetwork(packet) => {
|
TunnResult::WriteToNetwork(packet) => {
|
||||||
let socket = self.socket().await?;
|
let socket = &self.socket.as_ref().ok_or(anyhow!("socket not open"))?;
|
||||||
tracing::debug!("Our Encapsulated packet: {:?}", packet);
|
tracing::debug!("Our Encapsulated packet: {:?}", packet);
|
||||||
socket.send(packet).await?;
|
socket.send(packet).await?;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue