Update locking to be interior to PeerPcb

This commit is contained in:
Conrad Kramer 2023-12-09 19:47:41 -08:00 committed by Jett Chen
parent 261f24d9ef
commit 3e5a01ffbe
6 changed files with 57 additions and 77 deletions

View file

@ -54,11 +54,12 @@ impl DaemonInstance {
warn!("Got start, but tun interface already up."); warn!("Got start, but tun interface already up.");
} }
RunState::Idle => { RunState::Idle => {
debug!("Creating new TunInterface"); let raw = tun::TunInterface::retrieve().unwrap();
let retrieved = TunInterface::new(tun::TunInterface::retrieve().unwrap())?; debug!("TunInterface retrieved: {:?}", raw.name()?);
let retrieved = TunInterface::new(raw)?;
let tun_if = Arc::new(RwLock::new(retrieved)); let tun_if = Arc::new(RwLock::new(retrieved));
// let tun_if = Arc::new(RwLock::new(TunInterface::new(st.tun.open()?)?)); // let tun_if = Arc::new(RwLock::new(TunInterface::new(st.tun.open()?)?));
debug!("TunInterface created: {:?}", tun_if);
debug!("Setting tun_interface"); debug!("Setting tun_interface");
self.tun_interface = Some(tun_if.clone()); self.tun_interface = Some(tun_if.clone());

View file

@ -62,7 +62,8 @@ pub async fn daemon_main() -> Result<()> {
allowed_ips: vec![IpNetwork::V4(Ipv4Network::DEFAULT_ROUTE)], allowed_ips: vec![IpNetwork::V4(Ipv4Network::DEFAULT_ROUTE)],
}])?; }])?;
let mut inst = DaemonInstance::new(commands_rx, response_tx, Arc::new(RwLock::new(iface))); let mut inst: DaemonInstance =
DaemonInstance::new(commands_rx, response_tx, Arc::new(RwLock::new(iface)));
tracing::info!("Starting daemon jobs..."); tracing::info!("Starting daemon jobs...");

View file

@ -29,7 +29,7 @@ impl PacketInterface for tun::tokio::TunInterface {
} }
struct IndexedPcbs { struct IndexedPcbs {
pcbs: Vec<Arc<RwLock<PeerPcb>>>, pcbs: Vec<Arc<PeerPcb>>,
allowed_ips: IpNetworkTable<usize>, allowed_ips: IpNetworkTable<usize>,
} }
@ -46,7 +46,7 @@ impl IndexedPcbs {
for allowed_ip in pcb.allowed_ips.iter() { for allowed_ip in pcb.allowed_ips.iter() {
self.allowed_ips.insert(allowed_ip.clone(), idx); self.allowed_ips.insert(allowed_ip.clone(), idx);
} }
self.pcbs.insert(idx, Arc::new(RwLock::new(pcb))); self.pcbs.insert(idx, Arc::new(pcb));
} }
pub fn find(&self, addr: IpAddr) -> Option<usize> { pub fn find(&self, addr: IpAddr) -> Option<usize> {
@ -55,7 +55,7 @@ impl IndexedPcbs {
} }
pub async fn connect(&self, idx: usize, handle: JoinHandle<()>) { pub async fn connect(&self, idx: usize, handle: JoinHandle<()>) {
self.pcbs[idx].write().await.handle = Some(handle); self.pcbs[idx].handle.write().await.replace(handle);
} }
} }
@ -106,7 +106,7 @@ impl Interface {
let src = { let src = {
let src = match timeout( let src = match timeout(
Duration::from_millis(10), Duration::from_millis(10),
tun.write().await.recv(&mut buf[..]), tun.read().await.recv(&mut buf[..]),
) )
.await .await
{ {
@ -138,9 +138,10 @@ impl Interface {
tracing::debug!("found peer:{}", idx); tracing::debug!("found peer:{}", idx);
match pcbs.pcbs[idx].read().await.send(src).await { match pcbs.pcbs[idx].send(src).await {
Ok(..) => { Ok(..) => {
tracing::debug!("sent packet to peer {}", dst_addr); let addr = pcbs.pcbs[idx].endpoint;
tracing::debug!("sent packet to peer {}", addr);
} }
Err(e) => { Err(e) => {
log::error!("failed to send packet {}", e); log::error!("failed to send packet {}", e);
@ -166,14 +167,11 @@ impl Interface {
let pcb = pcbs.pcbs[i].clone(); let pcb = pcbs.pcbs[i].clone();
let tun = tun.clone(); let tun = tun.clone();
let tsk = async move { let tsk = async move {
{ if let Err(e) = pcb.open_if_closed().await {
let r1 = pcb.write().await.open_if_closed().await;
if let Err(e) = r1 {
log::error!("failed to open pcb: {}", e); log::error!("failed to open pcb: {}", e);
return return
} }
} let r2 = pcb.run(tun).await;
let r2 = pcb.read().await.run(tun).await;
if let Err(e) = r2 { if let Err(e) = r2 {
log::error!("failed to run pcb: {}", e); log::error!("failed to run pcb: {}", e);
return return

View file

@ -1,4 +1,8 @@
use std::{net::SocketAddr, sync::Arc}; use std::{
cell::{Cell, RefCell},
net::SocketAddr,
sync::Arc,
};
use anyhow::{anyhow, Error}; use anyhow::{anyhow, Error};
use fehler::throws; use fehler::throws;
@ -8,7 +12,6 @@ use tokio::{net::UdpSocket, sync::RwLock, task::JoinHandle};
use tun::tokio::TunInterface; use tun::tokio::TunInterface;
use super::{ use super::{
iface::PacketInterface,
noise::{TunnResult, Tunnel}, noise::{TunnResult, Tunnel},
Peer, Peer,
}; };
@ -17,8 +20,8 @@ use super::{
pub struct PeerPcb { pub struct PeerPcb {
pub endpoint: SocketAddr, pub endpoint: SocketAddr,
pub allowed_ips: Vec<IpNetwork>, pub allowed_ips: Vec<IpNetwork>,
pub handle: Option<JoinHandle<()>>, pub handle: RwLock<Option<JoinHandle<()>>>,
socket: Option<UdpSocket>, socket: RwLock<Option<UdpSocket>>,
tunnel: RwLock<Tunnel>, tunnel: RwLock<Tunnel>,
} }
@ -36,46 +39,35 @@ impl PeerPcb {
) )
.map_err(|s| anyhow::anyhow!("{}", s))?, .map_err(|s| anyhow::anyhow!("{}", s))?,
); );
Self { Self {
endpoint: peer.endpoint, endpoint: peer.endpoint,
allowed_ips: peer.allowed_ips, allowed_ips: peer.allowed_ips,
handle: None, handle: RwLock::new(None),
socket: None, socket: RwLock::new(None),
tunnel, tunnel,
} }
} }
pub async fn open_if_closed(&mut self) -> Result<(), Error> { pub async fn open_if_closed(&self) -> Result<(), Error> {
if self.socket.is_none() { if self.socket.read().await.is_none() {
let socket = UdpSocket::bind("0.0.0.0:0").await?; let socket = UdpSocket::bind("0.0.0.0:0").await?;
socket.connect(self.endpoint).await?; socket.connect(self.endpoint).await?;
self.socket = Some(socket); self.socket.write().await.replace(socket);
} }
Ok(()) Ok(())
} }
pub async fn run(&self, tun_interface: Arc<RwLock<TunInterface>>) -> Result<(), Error> { pub async fn run(&self, tun_interface: Arc<RwLock<TunInterface>>) -> Result<(), Error> {
let mut buf = [0u8; 3000];
tracing::debug!("starting read loop for pcb...");
loop {
tracing::debug!("waiting for packet");
let len = self.recv(&mut buf, tun_interface.clone()).await?;
tracing::debug!("received {} bytes", len);
}
}
pub async fn recv(
&self,
buf: &mut [u8],
tun_interface: Arc<RwLock<TunInterface>>,
) -> Result<usize, Error> {
tracing::debug!("starting read loop for pcb... for {:?}", &self); tracing::debug!("starting read loop for pcb... for {:?}", &self);
let rid: i32 = random(); let rid: i32 = random();
let mut buf: [u8; 3000] = [0u8; 3000];
tracing::debug!("start read loop {}", rid); tracing::debug!("start read loop {}", rid);
loop { loop {
tracing::debug!("{}: waiting for packet", rid); tracing::debug!("{}: waiting for packet", rid);
let Some(socket) = &self.socket else { continue }; let guard = self.socket.read().await;
let Some(socket) = guard.as_ref() else {
continue
};
let mut res_buf = [0; 1500]; let mut res_buf = [0; 1500];
// tracing::debug!("{} : waiting for readability on {:?}", rid, socket); // tracing::debug!("{} : waiting for readability on {:?}", rid, socket);
let len = match socket.recv(&mut res_buf).await { let len = match socket.recv(&mut res_buf).await {
@ -102,6 +94,7 @@ impl PeerPcb {
} }
TunnResult::WriteToNetwork(packet) => { TunnResult::WriteToNetwork(packet) => {
tracing::debug!("WriteToNetwork: {:?}", packet); tracing::debug!("WriteToNetwork: {:?}", packet);
self.open_if_closed().await?;
socket.send(packet).await?; socket.send(packet).await?;
tracing::debug!("WriteToNetwork done"); tracing::debug!("WriteToNetwork done");
res_dat = &[]; res_dat = &[];
@ -119,15 +112,9 @@ impl PeerPcb {
} }
} }
} }
return Ok(len)
} }
} }
pub async fn socket(&mut self) -> Result<&UdpSocket, Error> {
self.open_if_closed().await?;
Ok(self.socket.as_ref().expect("socket was just opened"))
}
pub async fn send(&self, src: &[u8]) -> Result<(), Error> { pub async fn send(&self, src: &[u8]) -> Result<(), Error> {
let mut dst_buf = [0u8; 3000]; let mut dst_buf = [0u8; 3000];
match self.tunnel.write().await.encapsulate(src, &mut dst_buf[..]) { match self.tunnel.write().await.encapsulate(src, &mut dst_buf[..]) {
@ -136,7 +123,12 @@ impl PeerPcb {
tracing::error!(message = "Encapsulate error", error = ?e) tracing::error!(message = "Encapsulate error", error = ?e)
} }
TunnResult::WriteToNetwork(packet) => { TunnResult::WriteToNetwork(packet) => {
let socket = &self.socket.as_ref().ok_or(anyhow!("socket not open"))?; self.open_if_closed().await?;
let handle = self.socket.read().await;
let Some(socket) = handle.as_ref() else {
tracing::error!("No socket for peer");
return Ok(())
};
tracing::debug!("Our Encapsulated packet: {:?}", packet); tracing::debug!("Our Encapsulated packet: {:?}", packet);
socket.send(packet).await?; socket.send(packet).await?;
} }

View file

@ -26,21 +26,12 @@ impl TunInterface {
} }
} }
// #[instrument] #[instrument]
pub async fn recv(&mut self, buf: &mut [u8]) -> io::Result<usize> { pub async fn recv(&self, buf: &mut [u8]) -> io::Result<usize> {
loop { loop {
// tracing::debug!("TunInterface receiving..."); let mut guard = self.inner.readable().await?;
let mut guard = self.inner.readable_mut().await?; match guard.try_io(|inner| inner.get_ref().recv(buf)) {
// tracing::debug!("Got! readable_mut"); Ok(result) => return result,
match guard.try_io(|inner| {
let raw_ref = (*inner).get_mut();
let recved = raw_ref.recv(buf);
recved
}) {
Ok(result) => {
tracing::debug!("HORRAY");
return result
}
Err(_would_block) => { Err(_would_block) => {
tracing::debug!("WouldBlock"); tracing::debug!("WouldBlock");
continue continue
@ -48,13 +39,4 @@ impl TunInterface {
} }
} }
} }
#[instrument]
pub async fn try_recv(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let mut guard = self.inner.readable_mut().await?;
match guard.try_io(|inner| (*inner).get_mut().recv(buf)) {
Ok(result) => Ok(result.unwrap_or_default()),
Err(_would_block) => Err(io::Error::new(io::ErrorKind::WouldBlock, "WouldBlock")),
}
}
} }

View file

@ -1,5 +1,6 @@
use std::{ use std::{
io::{Error, Read}, io::{Error, Read},
mem::MaybeUninit,
os::fd::{AsRawFd, FromRawFd, IntoRawFd, RawFd}, os::fd::{AsRawFd, FromRawFd, IntoRawFd, RawFd},
}; };
@ -38,14 +39,19 @@ impl IntoRawFd for TunInterface {
} }
} }
unsafe fn assume_init(buf: &[MaybeUninit<u8>]) -> &[u8] {
&*(buf as *const [MaybeUninit<u8>] as *const [u8])
}
impl TunInterface { impl TunInterface {
#[throws] #[throws]
#[instrument] #[instrument]
pub fn recv(&mut self, buf: &mut [u8]) -> usize { pub fn recv(&self, buf: &mut [u8]) -> usize {
// there might be a more efficient way to implement this // Use IoVec to read directly into target buffer
let tmp_buf = &mut [0u8; 1500]; let mut tmp_buf = [MaybeUninit::uninit(); 1500];
let len = self.socket.read(tmp_buf)?; let len = self.socket.recv(&mut tmp_buf)?;
buf[..len - 4].copy_from_slice(&tmp_buf[4..len]); let result_buf = unsafe { assume_init(&tmp_buf[4..len]) };
buf[..len - 4].copy_from_slice(&result_buf);
len - 4 len - 4
} }