Update locking to be interior to PeerPcb

This commit is contained in:
Conrad Kramer 2023-12-09 19:47:41 -08:00
parent 30cd00fc2b
commit 4408e9aca8
6 changed files with 57 additions and 77 deletions

View file

@ -54,11 +54,12 @@ impl DaemonInstance {
warn!("Got start, but tun interface already up.");
}
RunState::Idle => {
debug!("Creating new TunInterface");
let retrieved = TunInterface::new(tun::TunInterface::retrieve().unwrap())?;
let raw = tun::TunInterface::retrieve().unwrap();
debug!("TunInterface retrieved: {:?}", raw.name()?);
let retrieved = TunInterface::new(raw)?;
let tun_if = Arc::new(RwLock::new(retrieved));
// let tun_if = Arc::new(RwLock::new(TunInterface::new(st.tun.open()?)?));
debug!("TunInterface created: {:?}", tun_if);
debug!("Setting tun_interface");
self.tun_interface = Some(tun_if.clone());

View file

@ -62,7 +62,8 @@ pub async fn daemon_main() -> Result<()> {
allowed_ips: vec![IpNetwork::V4(Ipv4Network::DEFAULT_ROUTE)],
}])?;
let mut inst = DaemonInstance::new(commands_rx, response_tx, Arc::new(RwLock::new(iface)));
let mut inst: DaemonInstance =
DaemonInstance::new(commands_rx, response_tx, Arc::new(RwLock::new(iface)));
tracing::info!("Starting daemon jobs...");

View file

@ -29,7 +29,7 @@ impl PacketInterface for tun::tokio::TunInterface {
}
struct IndexedPcbs {
pcbs: Vec<Arc<RwLock<PeerPcb>>>,
pcbs: Vec<Arc<PeerPcb>>,
allowed_ips: IpNetworkTable<usize>,
}
@ -46,7 +46,7 @@ impl IndexedPcbs {
for allowed_ip in pcb.allowed_ips.iter() {
self.allowed_ips.insert(allowed_ip.clone(), idx);
}
self.pcbs.insert(idx, Arc::new(RwLock::new(pcb)));
self.pcbs.insert(idx, Arc::new(pcb));
}
pub fn find(&self, addr: IpAddr) -> Option<usize> {
@ -55,7 +55,7 @@ impl IndexedPcbs {
}
pub async fn connect(&self, idx: usize, handle: JoinHandle<()>) {
self.pcbs[idx].write().await.handle = Some(handle);
self.pcbs[idx].handle.write().await.replace(handle);
}
}
@ -106,7 +106,7 @@ impl Interface {
let src = {
let src = match timeout(
Duration::from_millis(10),
tun.write().await.recv(&mut buf[..]),
tun.read().await.recv(&mut buf[..]),
)
.await
{
@ -138,9 +138,10 @@ impl Interface {
tracing::debug!("found peer:{}", idx);
match pcbs.pcbs[idx].read().await.send(src).await {
match pcbs.pcbs[idx].send(src).await {
Ok(..) => {
tracing::debug!("sent packet to peer {}", dst_addr);
let addr = pcbs.pcbs[idx].endpoint;
tracing::debug!("sent packet to peer {}", addr);
}
Err(e) => {
log::error!("failed to send packet {}", e);
@ -166,14 +167,11 @@ impl Interface {
let pcb = pcbs.pcbs[i].clone();
let tun = tun.clone();
let tsk = async move {
{
let r1 = pcb.write().await.open_if_closed().await;
if let Err(e) = r1 {
if let Err(e) = pcb.open_if_closed().await {
log::error!("failed to open pcb: {}", e);
return
}
}
let r2 = pcb.read().await.run(tun).await;
let r2 = pcb.run(tun).await;
if let Err(e) = r2 {
log::error!("failed to run pcb: {}", e);
return

View file

@ -1,4 +1,8 @@
use std::{net::SocketAddr, sync::Arc};
use std::{
cell::{Cell, RefCell},
net::SocketAddr,
sync::Arc,
};
use anyhow::{anyhow, Error};
use fehler::throws;
@ -8,7 +12,6 @@ use tokio::{net::UdpSocket, sync::RwLock, task::JoinHandle};
use tun::tokio::TunInterface;
use super::{
iface::PacketInterface,
noise::{TunnResult, Tunnel},
Peer,
};
@ -17,8 +20,8 @@ use super::{
pub struct PeerPcb {
pub endpoint: SocketAddr,
pub allowed_ips: Vec<IpNetwork>,
pub handle: Option<JoinHandle<()>>,
socket: Option<UdpSocket>,
pub handle: RwLock<Option<JoinHandle<()>>>,
socket: RwLock<Option<UdpSocket>>,
tunnel: RwLock<Tunnel>,
}
@ -36,46 +39,35 @@ impl PeerPcb {
)
.map_err(|s| anyhow::anyhow!("{}", s))?,
);
Self {
endpoint: peer.endpoint,
allowed_ips: peer.allowed_ips,
handle: None,
socket: None,
handle: RwLock::new(None),
socket: RwLock::new(None),
tunnel,
}
}
pub async fn open_if_closed(&mut self) -> Result<(), Error> {
if self.socket.is_none() {
pub async fn open_if_closed(&self) -> Result<(), Error> {
if self.socket.read().await.is_none() {
let socket = UdpSocket::bind("0.0.0.0:0").await?;
socket.connect(self.endpoint).await?;
self.socket = Some(socket);
self.socket.write().await.replace(socket);
}
Ok(())
}
pub async fn run(&self, tun_interface: Arc<RwLock<TunInterface>>) -> Result<(), Error> {
let mut buf = [0u8; 3000];
tracing::debug!("starting read loop for pcb...");
loop {
tracing::debug!("waiting for packet");
let len = self.recv(&mut buf, tun_interface.clone()).await?;
tracing::debug!("received {} bytes", len);
}
}
pub async fn recv(
&self,
buf: &mut [u8],
tun_interface: Arc<RwLock<TunInterface>>,
) -> Result<usize, Error> {
tracing::debug!("starting read loop for pcb... for {:?}", &self);
let rid: i32 = random();
let mut buf: [u8; 3000] = [0u8; 3000];
tracing::debug!("start read loop {}", rid);
loop {
tracing::debug!("{}: waiting for packet", rid);
let Some(socket) = &self.socket else { continue };
let guard = self.socket.read().await;
let Some(socket) = guard.as_ref() else {
continue
};
let mut res_buf = [0; 1500];
// tracing::debug!("{} : waiting for readability on {:?}", rid, socket);
let len = match socket.recv(&mut res_buf).await {
@ -102,6 +94,7 @@ impl PeerPcb {
}
TunnResult::WriteToNetwork(packet) => {
tracing::debug!("WriteToNetwork: {:?}", packet);
self.open_if_closed().await?;
socket.send(packet).await?;
tracing::debug!("WriteToNetwork done");
res_dat = &[];
@ -119,15 +112,9 @@ impl PeerPcb {
}
}
}
return Ok(len)
}
}
pub async fn socket(&mut self) -> Result<&UdpSocket, Error> {
self.open_if_closed().await?;
Ok(self.socket.as_ref().expect("socket was just opened"))
}
pub async fn send(&self, src: &[u8]) -> Result<(), Error> {
let mut dst_buf = [0u8; 3000];
match self.tunnel.write().await.encapsulate(src, &mut dst_buf[..]) {
@ -136,7 +123,12 @@ impl PeerPcb {
tracing::error!(message = "Encapsulate error", error = ?e)
}
TunnResult::WriteToNetwork(packet) => {
let socket = &self.socket.as_ref().ok_or(anyhow!("socket not open"))?;
self.open_if_closed().await?;
let handle = self.socket.read().await;
let Some(socket) = handle.as_ref() else {
tracing::error!("No socket for peer");
return Ok(())
};
tracing::debug!("Our Encapsulated packet: {:?}", packet);
socket.send(packet).await?;
}

View file

@ -26,21 +26,12 @@ impl TunInterface {
}
}
// #[instrument]
pub async fn recv(&mut self, buf: &mut [u8]) -> io::Result<usize> {
#[instrument]
pub async fn recv(&self, buf: &mut [u8]) -> io::Result<usize> {
loop {
// tracing::debug!("TunInterface receiving...");
let mut guard = self.inner.readable_mut().await?;
// tracing::debug!("Got! readable_mut");
match guard.try_io(|inner| {
let raw_ref = (*inner).get_mut();
let recved = raw_ref.recv(buf);
recved
}) {
Ok(result) => {
tracing::debug!("HORRAY");
return result
}
let mut guard = self.inner.readable().await?;
match guard.try_io(|inner| inner.get_ref().recv(buf)) {
Ok(result) => return result,
Err(_would_block) => {
tracing::debug!("WouldBlock");
continue
@ -48,13 +39,4 @@ impl TunInterface {
}
}
}
#[instrument]
pub async fn try_recv(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let mut guard = self.inner.readable_mut().await?;
match guard.try_io(|inner| (*inner).get_mut().recv(buf)) {
Ok(result) => Ok(result.unwrap_or_default()),
Err(_would_block) => Err(io::Error::new(io::ErrorKind::WouldBlock, "WouldBlock")),
}
}
}

View file

@ -1,5 +1,6 @@
use std::{
io::{Error, Read},
mem::MaybeUninit,
os::fd::{AsRawFd, FromRawFd, IntoRawFd, RawFd},
};
@ -38,14 +39,19 @@ impl IntoRawFd for TunInterface {
}
}
unsafe fn assume_init(buf: &[MaybeUninit<u8>]) -> &[u8] {
&*(buf as *const [MaybeUninit<u8>] as *const [u8])
}
impl TunInterface {
#[throws]
#[instrument]
pub fn recv(&mut self, buf: &mut [u8]) -> usize {
// there might be a more efficient way to implement this
let tmp_buf = &mut [0u8; 1500];
let len = self.socket.read(tmp_buf)?;
buf[..len - 4].copy_from_slice(&tmp_buf[4..len]);
pub fn recv(&self, buf: &mut [u8]) -> usize {
// Use IoVec to read directly into target buffer
let mut tmp_buf = [MaybeUninit::uninit(); 1500];
let len = self.socket.recv(&mut tmp_buf)?;
let result_buf = unsafe { assume_init(&tmp_buf[4..len]) };
buf[..len - 4].copy_from_slice(&result_buf);
len - 4
}