Update locking to be interior to PeerPcb

This commit is contained in:
Conrad Kramer 2023-12-09 19:47:41 -08:00 committed by Jett Chen
parent 261f24d9ef
commit 3e5a01ffbe
6 changed files with 57 additions and 77 deletions

View file

@ -54,11 +54,12 @@ impl DaemonInstance {
warn!("Got start, but tun interface already up.");
}
RunState::Idle => {
debug!("Creating new TunInterface");
let retrieved = TunInterface::new(tun::TunInterface::retrieve().unwrap())?;
let raw = tun::TunInterface::retrieve().unwrap();
debug!("TunInterface retrieved: {:?}", raw.name()?);
let retrieved = TunInterface::new(raw)?;
let tun_if = Arc::new(RwLock::new(retrieved));
// let tun_if = Arc::new(RwLock::new(TunInterface::new(st.tun.open()?)?));
debug!("TunInterface created: {:?}", tun_if);
debug!("Setting tun_interface");
self.tun_interface = Some(tun_if.clone());

View file

@ -62,7 +62,8 @@ pub async fn daemon_main() -> Result<()> {
allowed_ips: vec![IpNetwork::V4(Ipv4Network::DEFAULT_ROUTE)],
}])?;
let mut inst = DaemonInstance::new(commands_rx, response_tx, Arc::new(RwLock::new(iface)));
let mut inst: DaemonInstance =
DaemonInstance::new(commands_rx, response_tx, Arc::new(RwLock::new(iface)));
tracing::info!("Starting daemon jobs...");

View file

@ -29,7 +29,7 @@ impl PacketInterface for tun::tokio::TunInterface {
}
struct IndexedPcbs {
pcbs: Vec<Arc<RwLock<PeerPcb>>>,
pcbs: Vec<Arc<PeerPcb>>,
allowed_ips: IpNetworkTable<usize>,
}
@ -46,7 +46,7 @@ impl IndexedPcbs {
for allowed_ip in pcb.allowed_ips.iter() {
self.allowed_ips.insert(allowed_ip.clone(), idx);
}
self.pcbs.insert(idx, Arc::new(RwLock::new(pcb)));
self.pcbs.insert(idx, Arc::new(pcb));
}
pub fn find(&self, addr: IpAddr) -> Option<usize> {
@ -55,7 +55,7 @@ impl IndexedPcbs {
}
pub async fn connect(&self, idx: usize, handle: JoinHandle<()>) {
self.pcbs[idx].write().await.handle = Some(handle);
self.pcbs[idx].handle.write().await.replace(handle);
}
}
@ -106,7 +106,7 @@ impl Interface {
let src = {
let src = match timeout(
Duration::from_millis(10),
tun.write().await.recv(&mut buf[..]),
tun.read().await.recv(&mut buf[..]),
)
.await
{
@ -138,9 +138,10 @@ impl Interface {
tracing::debug!("found peer:{}", idx);
match pcbs.pcbs[idx].read().await.send(src).await {
match pcbs.pcbs[idx].send(src).await {
Ok(..) => {
tracing::debug!("sent packet to peer {}", dst_addr);
let addr = pcbs.pcbs[idx].endpoint;
tracing::debug!("sent packet to peer {}", addr);
}
Err(e) => {
log::error!("failed to send packet {}", e);
@ -166,14 +167,11 @@ impl Interface {
let pcb = pcbs.pcbs[i].clone();
let tun = tun.clone();
let tsk = async move {
{
let r1 = pcb.write().await.open_if_closed().await;
if let Err(e) = r1 {
log::error!("failed to open pcb: {}", e);
return
}
if let Err(e) = pcb.open_if_closed().await {
log::error!("failed to open pcb: {}", e);
return
}
let r2 = pcb.read().await.run(tun).await;
let r2 = pcb.run(tun).await;
if let Err(e) = r2 {
log::error!("failed to run pcb: {}", e);
return

View file

@ -1,4 +1,8 @@
use std::{net::SocketAddr, sync::Arc};
use std::{
cell::{Cell, RefCell},
net::SocketAddr,
sync::Arc,
};
use anyhow::{anyhow, Error};
use fehler::throws;
@ -8,7 +12,6 @@ use tokio::{net::UdpSocket, sync::RwLock, task::JoinHandle};
use tun::tokio::TunInterface;
use super::{
iface::PacketInterface,
noise::{TunnResult, Tunnel},
Peer,
};
@ -17,8 +20,8 @@ use super::{
pub struct PeerPcb {
pub endpoint: SocketAddr,
pub allowed_ips: Vec<IpNetwork>,
pub handle: Option<JoinHandle<()>>,
socket: Option<UdpSocket>,
pub handle: RwLock<Option<JoinHandle<()>>>,
socket: RwLock<Option<UdpSocket>>,
tunnel: RwLock<Tunnel>,
}
@ -36,46 +39,35 @@ impl PeerPcb {
)
.map_err(|s| anyhow::anyhow!("{}", s))?,
);
Self {
endpoint: peer.endpoint,
allowed_ips: peer.allowed_ips,
handle: None,
socket: None,
handle: RwLock::new(None),
socket: RwLock::new(None),
tunnel,
}
}
pub async fn open_if_closed(&mut self) -> Result<(), Error> {
if self.socket.is_none() {
pub async fn open_if_closed(&self) -> Result<(), Error> {
if self.socket.read().await.is_none() {
let socket = UdpSocket::bind("0.0.0.0:0").await?;
socket.connect(self.endpoint).await?;
self.socket = Some(socket);
self.socket.write().await.replace(socket);
}
Ok(())
}
pub async fn run(&self, tun_interface: Arc<RwLock<TunInterface>>) -> Result<(), Error> {
let mut buf = [0u8; 3000];
tracing::debug!("starting read loop for pcb...");
loop {
tracing::debug!("waiting for packet");
let len = self.recv(&mut buf, tun_interface.clone()).await?;
tracing::debug!("received {} bytes", len);
}
}
pub async fn recv(
&self,
buf: &mut [u8],
tun_interface: Arc<RwLock<TunInterface>>,
) -> Result<usize, Error> {
tracing::debug!("starting read loop for pcb... for {:?}", &self);
let rid: i32 = random();
let mut buf: [u8; 3000] = [0u8; 3000];
tracing::debug!("start read loop {}", rid);
loop {
tracing::debug!("{}: waiting for packet", rid);
let Some(socket) = &self.socket else { continue };
let guard = self.socket.read().await;
let Some(socket) = guard.as_ref() else {
continue
};
let mut res_buf = [0; 1500];
// tracing::debug!("{} : waiting for readability on {:?}", rid, socket);
let len = match socket.recv(&mut res_buf).await {
@ -102,6 +94,7 @@ impl PeerPcb {
}
TunnResult::WriteToNetwork(packet) => {
tracing::debug!("WriteToNetwork: {:?}", packet);
self.open_if_closed().await?;
socket.send(packet).await?;
tracing::debug!("WriteToNetwork done");
res_dat = &[];
@ -119,15 +112,9 @@ impl PeerPcb {
}
}
}
return Ok(len)
}
}
pub async fn socket(&mut self) -> Result<&UdpSocket, Error> {
self.open_if_closed().await?;
Ok(self.socket.as_ref().expect("socket was just opened"))
}
pub async fn send(&self, src: &[u8]) -> Result<(), Error> {
let mut dst_buf = [0u8; 3000];
match self.tunnel.write().await.encapsulate(src, &mut dst_buf[..]) {
@ -136,7 +123,12 @@ impl PeerPcb {
tracing::error!(message = "Encapsulate error", error = ?e)
}
TunnResult::WriteToNetwork(packet) => {
let socket = &self.socket.as_ref().ok_or(anyhow!("socket not open"))?;
self.open_if_closed().await?;
let handle = self.socket.read().await;
let Some(socket) = handle.as_ref() else {
tracing::error!("No socket for peer");
return Ok(())
};
tracing::debug!("Our Encapsulated packet: {:?}", packet);
socket.send(packet).await?;
}