Update locking to be interior to PeerPcb
This commit is contained in:
parent
261f24d9ef
commit
3e5a01ffbe
6 changed files with 57 additions and 77 deletions
|
|
@ -54,11 +54,12 @@ impl DaemonInstance {
|
||||||
warn!("Got start, but tun interface already up.");
|
warn!("Got start, but tun interface already up.");
|
||||||
}
|
}
|
||||||
RunState::Idle => {
|
RunState::Idle => {
|
||||||
debug!("Creating new TunInterface");
|
let raw = tun::TunInterface::retrieve().unwrap();
|
||||||
let retrieved = TunInterface::new(tun::TunInterface::retrieve().unwrap())?;
|
debug!("TunInterface retrieved: {:?}", raw.name()?);
|
||||||
|
|
||||||
|
let retrieved = TunInterface::new(raw)?;
|
||||||
let tun_if = Arc::new(RwLock::new(retrieved));
|
let tun_if = Arc::new(RwLock::new(retrieved));
|
||||||
// let tun_if = Arc::new(RwLock::new(TunInterface::new(st.tun.open()?)?));
|
// let tun_if = Arc::new(RwLock::new(TunInterface::new(st.tun.open()?)?));
|
||||||
debug!("TunInterface created: {:?}", tun_if);
|
|
||||||
|
|
||||||
debug!("Setting tun_interface");
|
debug!("Setting tun_interface");
|
||||||
self.tun_interface = Some(tun_if.clone());
|
self.tun_interface = Some(tun_if.clone());
|
||||||
|
|
|
||||||
|
|
@ -62,7 +62,8 @@ pub async fn daemon_main() -> Result<()> {
|
||||||
allowed_ips: vec![IpNetwork::V4(Ipv4Network::DEFAULT_ROUTE)],
|
allowed_ips: vec![IpNetwork::V4(Ipv4Network::DEFAULT_ROUTE)],
|
||||||
}])?;
|
}])?;
|
||||||
|
|
||||||
let mut inst = DaemonInstance::new(commands_rx, response_tx, Arc::new(RwLock::new(iface)));
|
let mut inst: DaemonInstance =
|
||||||
|
DaemonInstance::new(commands_rx, response_tx, Arc::new(RwLock::new(iface)));
|
||||||
|
|
||||||
tracing::info!("Starting daemon jobs...");
|
tracing::info!("Starting daemon jobs...");
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -29,7 +29,7 @@ impl PacketInterface for tun::tokio::TunInterface {
|
||||||
}
|
}
|
||||||
|
|
||||||
struct IndexedPcbs {
|
struct IndexedPcbs {
|
||||||
pcbs: Vec<Arc<RwLock<PeerPcb>>>,
|
pcbs: Vec<Arc<PeerPcb>>,
|
||||||
allowed_ips: IpNetworkTable<usize>,
|
allowed_ips: IpNetworkTable<usize>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -46,7 +46,7 @@ impl IndexedPcbs {
|
||||||
for allowed_ip in pcb.allowed_ips.iter() {
|
for allowed_ip in pcb.allowed_ips.iter() {
|
||||||
self.allowed_ips.insert(allowed_ip.clone(), idx);
|
self.allowed_ips.insert(allowed_ip.clone(), idx);
|
||||||
}
|
}
|
||||||
self.pcbs.insert(idx, Arc::new(RwLock::new(pcb)));
|
self.pcbs.insert(idx, Arc::new(pcb));
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn find(&self, addr: IpAddr) -> Option<usize> {
|
pub fn find(&self, addr: IpAddr) -> Option<usize> {
|
||||||
|
|
@ -55,7 +55,7 @@ impl IndexedPcbs {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn connect(&self, idx: usize, handle: JoinHandle<()>) {
|
pub async fn connect(&self, idx: usize, handle: JoinHandle<()>) {
|
||||||
self.pcbs[idx].write().await.handle = Some(handle);
|
self.pcbs[idx].handle.write().await.replace(handle);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -106,7 +106,7 @@ impl Interface {
|
||||||
let src = {
|
let src = {
|
||||||
let src = match timeout(
|
let src = match timeout(
|
||||||
Duration::from_millis(10),
|
Duration::from_millis(10),
|
||||||
tun.write().await.recv(&mut buf[..]),
|
tun.read().await.recv(&mut buf[..]),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
|
|
@ -138,9 +138,10 @@ impl Interface {
|
||||||
|
|
||||||
tracing::debug!("found peer:{}", idx);
|
tracing::debug!("found peer:{}", idx);
|
||||||
|
|
||||||
match pcbs.pcbs[idx].read().await.send(src).await {
|
match pcbs.pcbs[idx].send(src).await {
|
||||||
Ok(..) => {
|
Ok(..) => {
|
||||||
tracing::debug!("sent packet to peer {}", dst_addr);
|
let addr = pcbs.pcbs[idx].endpoint;
|
||||||
|
tracing::debug!("sent packet to peer {}", addr);
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
log::error!("failed to send packet {}", e);
|
log::error!("failed to send packet {}", e);
|
||||||
|
|
@ -166,14 +167,11 @@ impl Interface {
|
||||||
let pcb = pcbs.pcbs[i].clone();
|
let pcb = pcbs.pcbs[i].clone();
|
||||||
let tun = tun.clone();
|
let tun = tun.clone();
|
||||||
let tsk = async move {
|
let tsk = async move {
|
||||||
{
|
if let Err(e) = pcb.open_if_closed().await {
|
||||||
let r1 = pcb.write().await.open_if_closed().await;
|
|
||||||
if let Err(e) = r1 {
|
|
||||||
log::error!("failed to open pcb: {}", e);
|
log::error!("failed to open pcb: {}", e);
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
let r2 = pcb.run(tun).await;
|
||||||
let r2 = pcb.read().await.run(tun).await;
|
|
||||||
if let Err(e) = r2 {
|
if let Err(e) = r2 {
|
||||||
log::error!("failed to run pcb: {}", e);
|
log::error!("failed to run pcb: {}", e);
|
||||||
return
|
return
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,8 @@
|
||||||
use std::{net::SocketAddr, sync::Arc};
|
use std::{
|
||||||
|
cell::{Cell, RefCell},
|
||||||
|
net::SocketAddr,
|
||||||
|
sync::Arc,
|
||||||
|
};
|
||||||
|
|
||||||
use anyhow::{anyhow, Error};
|
use anyhow::{anyhow, Error};
|
||||||
use fehler::throws;
|
use fehler::throws;
|
||||||
|
|
@ -8,7 +12,6 @@ use tokio::{net::UdpSocket, sync::RwLock, task::JoinHandle};
|
||||||
use tun::tokio::TunInterface;
|
use tun::tokio::TunInterface;
|
||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
iface::PacketInterface,
|
|
||||||
noise::{TunnResult, Tunnel},
|
noise::{TunnResult, Tunnel},
|
||||||
Peer,
|
Peer,
|
||||||
};
|
};
|
||||||
|
|
@ -17,8 +20,8 @@ use super::{
|
||||||
pub struct PeerPcb {
|
pub struct PeerPcb {
|
||||||
pub endpoint: SocketAddr,
|
pub endpoint: SocketAddr,
|
||||||
pub allowed_ips: Vec<IpNetwork>,
|
pub allowed_ips: Vec<IpNetwork>,
|
||||||
pub handle: Option<JoinHandle<()>>,
|
pub handle: RwLock<Option<JoinHandle<()>>>,
|
||||||
socket: Option<UdpSocket>,
|
socket: RwLock<Option<UdpSocket>>,
|
||||||
tunnel: RwLock<Tunnel>,
|
tunnel: RwLock<Tunnel>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -36,46 +39,35 @@ impl PeerPcb {
|
||||||
)
|
)
|
||||||
.map_err(|s| anyhow::anyhow!("{}", s))?,
|
.map_err(|s| anyhow::anyhow!("{}", s))?,
|
||||||
);
|
);
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
endpoint: peer.endpoint,
|
endpoint: peer.endpoint,
|
||||||
allowed_ips: peer.allowed_ips,
|
allowed_ips: peer.allowed_ips,
|
||||||
handle: None,
|
handle: RwLock::new(None),
|
||||||
socket: None,
|
socket: RwLock::new(None),
|
||||||
tunnel,
|
tunnel,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn open_if_closed(&mut self) -> Result<(), Error> {
|
pub async fn open_if_closed(&self) -> Result<(), Error> {
|
||||||
if self.socket.is_none() {
|
if self.socket.read().await.is_none() {
|
||||||
let socket = UdpSocket::bind("0.0.0.0:0").await?;
|
let socket = UdpSocket::bind("0.0.0.0:0").await?;
|
||||||
socket.connect(self.endpoint).await?;
|
socket.connect(self.endpoint).await?;
|
||||||
self.socket = Some(socket);
|
self.socket.write().await.replace(socket);
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn run(&self, tun_interface: Arc<RwLock<TunInterface>>) -> Result<(), Error> {
|
pub async fn run(&self, tun_interface: Arc<RwLock<TunInterface>>) -> Result<(), Error> {
|
||||||
let mut buf = [0u8; 3000];
|
|
||||||
tracing::debug!("starting read loop for pcb...");
|
|
||||||
loop {
|
|
||||||
tracing::debug!("waiting for packet");
|
|
||||||
let len = self.recv(&mut buf, tun_interface.clone()).await?;
|
|
||||||
tracing::debug!("received {} bytes", len);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn recv(
|
|
||||||
&self,
|
|
||||||
buf: &mut [u8],
|
|
||||||
tun_interface: Arc<RwLock<TunInterface>>,
|
|
||||||
) -> Result<usize, Error> {
|
|
||||||
tracing::debug!("starting read loop for pcb... for {:?}", &self);
|
tracing::debug!("starting read loop for pcb... for {:?}", &self);
|
||||||
let rid: i32 = random();
|
let rid: i32 = random();
|
||||||
|
let mut buf: [u8; 3000] = [0u8; 3000];
|
||||||
tracing::debug!("start read loop {}", rid);
|
tracing::debug!("start read loop {}", rid);
|
||||||
loop {
|
loop {
|
||||||
tracing::debug!("{}: waiting for packet", rid);
|
tracing::debug!("{}: waiting for packet", rid);
|
||||||
let Some(socket) = &self.socket else { continue };
|
let guard = self.socket.read().await;
|
||||||
|
let Some(socket) = guard.as_ref() else {
|
||||||
|
continue
|
||||||
|
};
|
||||||
let mut res_buf = [0; 1500];
|
let mut res_buf = [0; 1500];
|
||||||
// tracing::debug!("{} : waiting for readability on {:?}", rid, socket);
|
// tracing::debug!("{} : waiting for readability on {:?}", rid, socket);
|
||||||
let len = match socket.recv(&mut res_buf).await {
|
let len = match socket.recv(&mut res_buf).await {
|
||||||
|
|
@ -102,6 +94,7 @@ impl PeerPcb {
|
||||||
}
|
}
|
||||||
TunnResult::WriteToNetwork(packet) => {
|
TunnResult::WriteToNetwork(packet) => {
|
||||||
tracing::debug!("WriteToNetwork: {:?}", packet);
|
tracing::debug!("WriteToNetwork: {:?}", packet);
|
||||||
|
self.open_if_closed().await?;
|
||||||
socket.send(packet).await?;
|
socket.send(packet).await?;
|
||||||
tracing::debug!("WriteToNetwork done");
|
tracing::debug!("WriteToNetwork done");
|
||||||
res_dat = &[];
|
res_dat = &[];
|
||||||
|
|
@ -119,15 +112,9 @@ impl PeerPcb {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return Ok(len)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn socket(&mut self) -> Result<&UdpSocket, Error> {
|
|
||||||
self.open_if_closed().await?;
|
|
||||||
Ok(self.socket.as_ref().expect("socket was just opened"))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn send(&self, src: &[u8]) -> Result<(), Error> {
|
pub async fn send(&self, src: &[u8]) -> Result<(), Error> {
|
||||||
let mut dst_buf = [0u8; 3000];
|
let mut dst_buf = [0u8; 3000];
|
||||||
match self.tunnel.write().await.encapsulate(src, &mut dst_buf[..]) {
|
match self.tunnel.write().await.encapsulate(src, &mut dst_buf[..]) {
|
||||||
|
|
@ -136,7 +123,12 @@ impl PeerPcb {
|
||||||
tracing::error!(message = "Encapsulate error", error = ?e)
|
tracing::error!(message = "Encapsulate error", error = ?e)
|
||||||
}
|
}
|
||||||
TunnResult::WriteToNetwork(packet) => {
|
TunnResult::WriteToNetwork(packet) => {
|
||||||
let socket = &self.socket.as_ref().ok_or(anyhow!("socket not open"))?;
|
self.open_if_closed().await?;
|
||||||
|
let handle = self.socket.read().await;
|
||||||
|
let Some(socket) = handle.as_ref() else {
|
||||||
|
tracing::error!("No socket for peer");
|
||||||
|
return Ok(())
|
||||||
|
};
|
||||||
tracing::debug!("Our Encapsulated packet: {:?}", packet);
|
tracing::debug!("Our Encapsulated packet: {:?}", packet);
|
||||||
socket.send(packet).await?;
|
socket.send(packet).await?;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -26,21 +26,12 @@ impl TunInterface {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// #[instrument]
|
#[instrument]
|
||||||
pub async fn recv(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
pub async fn recv(&self, buf: &mut [u8]) -> io::Result<usize> {
|
||||||
loop {
|
loop {
|
||||||
// tracing::debug!("TunInterface receiving...");
|
let mut guard = self.inner.readable().await?;
|
||||||
let mut guard = self.inner.readable_mut().await?;
|
match guard.try_io(|inner| inner.get_ref().recv(buf)) {
|
||||||
// tracing::debug!("Got! readable_mut");
|
Ok(result) => return result,
|
||||||
match guard.try_io(|inner| {
|
|
||||||
let raw_ref = (*inner).get_mut();
|
|
||||||
let recved = raw_ref.recv(buf);
|
|
||||||
recved
|
|
||||||
}) {
|
|
||||||
Ok(result) => {
|
|
||||||
tracing::debug!("HORRAY");
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
Err(_would_block) => {
|
Err(_would_block) => {
|
||||||
tracing::debug!("WouldBlock");
|
tracing::debug!("WouldBlock");
|
||||||
continue
|
continue
|
||||||
|
|
@ -48,13 +39,4 @@ impl TunInterface {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[instrument]
|
|
||||||
pub async fn try_recv(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
|
||||||
let mut guard = self.inner.readable_mut().await?;
|
|
||||||
match guard.try_io(|inner| (*inner).get_mut().recv(buf)) {
|
|
||||||
Ok(result) => Ok(result.unwrap_or_default()),
|
|
||||||
Err(_would_block) => Err(io::Error::new(io::ErrorKind::WouldBlock, "WouldBlock")),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
use std::{
|
use std::{
|
||||||
io::{Error, Read},
|
io::{Error, Read},
|
||||||
|
mem::MaybeUninit,
|
||||||
os::fd::{AsRawFd, FromRawFd, IntoRawFd, RawFd},
|
os::fd::{AsRawFd, FromRawFd, IntoRawFd, RawFd},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -38,14 +39,19 @@ impl IntoRawFd for TunInterface {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
unsafe fn assume_init(buf: &[MaybeUninit<u8>]) -> &[u8] {
|
||||||
|
&*(buf as *const [MaybeUninit<u8>] as *const [u8])
|
||||||
|
}
|
||||||
|
|
||||||
impl TunInterface {
|
impl TunInterface {
|
||||||
#[throws]
|
#[throws]
|
||||||
#[instrument]
|
#[instrument]
|
||||||
pub fn recv(&mut self, buf: &mut [u8]) -> usize {
|
pub fn recv(&self, buf: &mut [u8]) -> usize {
|
||||||
// there might be a more efficient way to implement this
|
// Use IoVec to read directly into target buffer
|
||||||
let tmp_buf = &mut [0u8; 1500];
|
let mut tmp_buf = [MaybeUninit::uninit(); 1500];
|
||||||
let len = self.socket.read(tmp_buf)?;
|
let len = self.socket.recv(&mut tmp_buf)?;
|
||||||
buf[..len - 4].copy_from_slice(&tmp_buf[4..len]);
|
let result_buf = unsafe { assume_init(&tmp_buf[4..len]) };
|
||||||
|
buf[..len - 4].copy_from_slice(&result_buf);
|
||||||
len - 4
|
len - 4
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue