Drive daemon tunnels from stored networks

This commit is contained in:
Conrad Kramer 2026-03-30 19:01:58 -07:00
parent 3fb0269d7c
commit 450e9c6fcd
4 changed files with 726 additions and 139 deletions

View file

@ -1,48 +1,27 @@
use std::{ use std::{
ops::Deref,
path::{Path, PathBuf}, path::{Path, PathBuf},
sync::Arc, sync::Arc,
time::Duration,
}; };
use anyhow::Result; use anyhow::Result;
use rusqlite::Connection; use rusqlite::Connection;
use tokio::sync::{mpsc, watch, Notify, RwLock}; use tokio::sync::{mpsc, watch, RwLock};
use tokio_stream::wrappers::ReceiverStream; use tokio_stream::wrappers::ReceiverStream;
use tonic::{Request, Response, Status as RspStatus}; use tonic::{Request, Response, Status as RspStatus};
use tracing::{debug, info, warn}; use tracing::warn;
use tun::{tokio::TunInterface, TunOptions}; use tun::tokio::TunInterface;
use super::rpc::grpc_defs::{ use super::{
networks_server::Networks, rpc::grpc_defs::{
tunnel_server::Tunnel, networks_server::Networks, tunnel_server::Tunnel, Empty, Network, NetworkDeleteRequest,
Empty, NetworkListResponse, NetworkReorderRequest, State as RPCTunnelState,
Network, TunnelConfigurationResponse, TunnelStatusResponse,
NetworkDeleteRequest, },
NetworkListResponse, runtime::{ActiveTunnel, ResolvedTunnel},
NetworkReorderRequest,
State as RPCTunnelState,
TunnelConfigurationResponse,
TunnelStatusResponse,
}; };
use crate::{ use crate::{
daemon::rpc::{ daemon::rpc::ServerConfig,
DaemonCommand, database::{add_network, delete_network, get_connection, list_networks, reorder_network},
DaemonNotification,
DaemonResponse,
DaemonResponseData,
ServerConfig,
ServerInfo,
},
database::{
add_network,
delete_network,
get_connection,
list_networks,
load_interface,
reorder_network,
},
wireguard::{Config, Interface},
}; };
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -52,10 +31,10 @@ enum RunState {
} }
impl RunState { impl RunState {
pub fn to_rpc(&self) -> RPCTunnelState { fn to_rpc(&self) -> RPCTunnelState {
match self { match self {
RunState::Running => RPCTunnelState::Running, Self::Running => RPCTunnelState::Running,
RunState::Idle => RPCTunnelState::Stopped, Self::Idle => RPCTunnelState::Stopped,
} }
} }
} }
@ -63,30 +42,24 @@ impl RunState {
#[derive(Clone)] #[derive(Clone)]
pub struct DaemonRPCServer { pub struct DaemonRPCServer {
tun_interface: Arc<RwLock<Option<TunInterface>>>, tun_interface: Arc<RwLock<Option<TunInterface>>>,
wg_interface: Arc<RwLock<Interface>>,
config: Arc<RwLock<Config>>,
db_path: Option<PathBuf>, db_path: Option<PathBuf>,
wg_state_chan: (watch::Sender<RunState>, watch::Receiver<RunState>), wg_state_chan: (watch::Sender<RunState>, watch::Receiver<RunState>),
network_update_chan: (watch::Sender<()>, watch::Receiver<()>), network_update_chan: (watch::Sender<()>, watch::Receiver<()>),
active_tunnel: Arc<RwLock<Option<ActiveTunnel>>>,
} }
impl DaemonRPCServer { impl DaemonRPCServer {
pub fn new( pub fn new(db_path: Option<&Path>) -> Result<Self> {
wg_interface: Arc<RwLock<Interface>>,
config: Arc<RwLock<Config>>,
db_path: Option<&Path>,
) -> Result<Self> {
Ok(Self { Ok(Self {
tun_interface: Arc::new(RwLock::new(None)), tun_interface: Arc::new(RwLock::new(None)),
wg_interface, db_path: db_path.map(Path::to_owned),
config,
db_path: db_path.map(|p| p.to_owned()),
wg_state_chan: watch::channel(RunState::Idle), wg_state_chan: watch::channel(RunState::Idle),
network_update_chan: watch::channel(()), network_update_chan: watch::channel(()),
active_tunnel: Arc::new(RwLock::new(None)),
}) })
} }
pub fn get_connection(&self) -> Result<Connection, RspStatus> { fn get_connection(&self) -> Result<Connection, RspStatus> {
get_connection(self.db_path.as_deref()).map_err(proc_err) get_connection(self.db_path.as_deref()).map_err(proc_err)
} }
@ -94,13 +67,71 @@ impl DaemonRPCServer {
self.wg_state_chan.0.send(state).map_err(proc_err) self.wg_state_chan.0.send(state).map_err(proc_err)
} }
async fn get_wg_state(&self) -> RunState {
self.wg_state_chan.1.borrow().to_owned()
}
async fn notify_network_update(&self) -> Result<(), RspStatus> { async fn notify_network_update(&self) -> Result<(), RspStatus> {
self.network_update_chan.0.send(()).map_err(proc_err) self.network_update_chan.0.send(()).map_err(proc_err)
} }
async fn resolve_tunnel(&self) -> Result<Option<ResolvedTunnel>, RspStatus> {
let conn = self.get_connection()?;
let networks = list_networks(&conn).map_err(proc_err)?;
ResolvedTunnel::from_networks(&networks).map_err(proc_err)
}
async fn current_tunnel_configuration(&self) -> Result<TunnelConfigurationResponse, RspStatus> {
match self.resolve_tunnel().await? {
Some(config) => {
let config = config.server_config().map_err(proc_err)?;
Ok(configuration_rsp(config))
}
None => Ok(empty_configuration_rsp()),
}
}
async fn stop_active_tunnel(&self) -> Result<bool, RspStatus> {
let current = { self.active_tunnel.write().await.take() };
let Some(current) = current else {
return Ok(false);
};
current
.shutdown(&self.tun_interface)
.await
.map_err(proc_err)?;
self.set_wg_state(RunState::Idle).await?;
Ok(true)
}
async fn replace_active_tunnel(&self, desired: ResolvedTunnel) -> Result<(), RspStatus> {
let _ = self.stop_active_tunnel().await?;
let active = desired
.start(self.tun_interface.clone())
.await
.map_err(proc_err)?;
self.active_tunnel.write().await.replace(active);
self.set_wg_state(RunState::Running).await?;
Ok(())
}
async fn reconcile_runtime(&self) -> Result<(), RspStatus> {
let desired = self.resolve_tunnel().await?;
let Some(desired) = desired else {
let _ = self.stop_active_tunnel().await?;
return Ok(());
};
let needs_restart = {
let guard = self.active_tunnel.read().await;
guard
.as_ref()
.map(|active| active.identity() != desired.identity())
.unwrap_or(false)
};
if needs_restart {
self.replace_active_tunnel(desired).await?;
}
Ok(())
}
} }
#[tonic::async_trait] #[tonic::async_trait]
@ -113,55 +144,49 @@ impl Tunnel for DaemonRPCServer {
_request: Request<Empty>, _request: Request<Empty>,
) -> Result<Response<Self::TunnelConfigurationStream>, RspStatus> { ) -> Result<Response<Self::TunnelConfigurationStream>, RspStatus> {
let (tx, rx) = mpsc::channel(10); let (tx, rx) = mpsc::channel(10);
let server = self.clone();
let mut sub = self.network_update_chan.1.clone();
tokio::spawn(async move { tokio::spawn(async move {
let serv_config = ServerConfig::default(); loop {
tx.send(Ok(TunnelConfigurationResponse { let response = server.current_tunnel_configuration().await;
mtu: serv_config.mtu.unwrap_or(1000), if tx.send(response).await.is_err() {
addresses: serv_config.address, break;
})) }
.await if sub.changed().await.is_err() {
break;
}
}
}); });
Ok(Response::new(ReceiverStream::new(rx))) Ok(Response::new(ReceiverStream::new(rx)))
} }
async fn tunnel_start(&self, _request: Request<Empty>) -> Result<Response<Empty>, RspStatus> { async fn tunnel_start(&self, _request: Request<Empty>) -> Result<Response<Empty>, RspStatus> {
let wg_state = self.get_wg_state().await; let desired = self
match wg_state { .resolve_tunnel()
RunState::Idle => { .await?
let tun_if = TunOptions::new().open()?; .ok_or_else(|| RspStatus::failed_precondition("no stored network configured"))?;
debug!("Setting tun on wg_interface"); let already_running = {
self.tun_interface.write().await.replace(tun_if); let guard = self.active_tunnel.read().await;
self.wg_interface guard
.write() .as_ref()
.await .map(|active| active.identity() == desired.identity())
.set_tun_ref(self.tun_interface.clone()) .unwrap_or(false)
.await; };
debug!("tun set on wg_interface");
debug!("Setting tun_interface");
debug!("tun_interface set: {:?}", self.tun_interface);
debug!("Cloning wg_interface");
let tmp_wg = self.wg_interface.clone();
let run_task = tokio::spawn(async move {
let twlock = tmp_wg.read().await;
twlock.run().await
});
self.set_wg_state(RunState::Running).await?;
}
RunState::Running => {
warn!("Got start, but tun interface already up.");
}
}
if already_running {
warn!("Got start, but active tunnel already matches desired network.");
return Ok(Response::new(Empty {})); return Ok(Response::new(Empty {}));
} }
self.replace_active_tunnel(desired).await?;
Ok(Response::new(Empty {}))
}
async fn tunnel_stop(&self, _request: Request<Empty>) -> Result<Response<Empty>, RspStatus> { async fn tunnel_stop(&self, _request: Request<Empty>) -> Result<Response<Empty>, RspStatus> {
self.wg_interface.write().await.remove_tun().await; let _ = self.stop_active_tunnel().await?;
self.set_wg_state(RunState::Idle).await?; Ok(Response::new(Empty {}))
return Ok(Response::new(Empty {}));
} }
async fn tunnel_status( async fn tunnel_status(
@ -172,13 +197,16 @@ impl Tunnel for DaemonRPCServer {
let mut state_rx = self.wg_state_chan.1.clone(); let mut state_rx = self.wg_state_chan.1.clone();
tokio::spawn(async move { tokio::spawn(async move {
let cur = state_rx.borrow_and_update().to_owned(); let cur = state_rx.borrow_and_update().to_owned();
tx.send(Ok(status_rsp(cur))).await; if tx.send(Ok(status_rsp(cur))).await.is_err() {
return;
}
loop { loop {
state_rx.changed().await.unwrap(); if state_rx.changed().await.is_err() {
break;
}
let cur = state_rx.borrow().to_owned(); let cur = state_rx.borrow().to_owned();
let res = tx.send(Ok(status_rsp(cur))).await; if tx.send(Ok(status_rsp(cur))).await.is_err() {
if res.is_err() {
eprintln!("Tunnel status channel closed");
break; break;
} }
} }
@ -196,6 +224,7 @@ impl Networks for DaemonRPCServer {
let network = request.into_inner(); let network = request.into_inner();
add_network(&conn, &network).map_err(proc_err)?; add_network(&conn, &network).map_err(proc_err)?;
self.notify_network_update().await?; self.notify_network_update().await?;
self.reconcile_runtime().await?;
Ok(Response::new(Empty {})) Ok(Response::new(Empty {}))
} }
@ -203,7 +232,6 @@ impl Networks for DaemonRPCServer {
&self, &self,
_request: Request<Empty>, _request: Request<Empty>,
) -> Result<Response<Self::NetworkListStream>, RspStatus> { ) -> Result<Response<Self::NetworkListStream>, RspStatus> {
debug!("Mock network_list called");
let (tx, rx) = mpsc::channel(10); let (tx, rx) = mpsc::channel(10);
let conn = self.get_connection()?; let conn = self.get_connection()?;
let mut sub = self.network_update_chan.1.clone(); let mut sub = self.network_update_chan.1.clone();
@ -212,12 +240,12 @@ impl Networks for DaemonRPCServer {
let networks = list_networks(&conn) let networks = list_networks(&conn)
.map(|res| NetworkListResponse { network: res }) .map(|res| NetworkListResponse { network: res })
.map_err(proc_err); .map_err(proc_err);
let res = tx.send(networks).await; if tx.send(networks).await.is_err() {
if res.is_err() { break;
eprintln!("Network list channel closed"); }
if sub.changed().await.is_err() {
break; break;
} }
sub.changed().await.unwrap();
} }
}); });
Ok(Response::new(ReceiverStream::new(rx))) Ok(Response::new(ReceiverStream::new(rx)))
@ -230,6 +258,7 @@ impl Networks for DaemonRPCServer {
let conn = self.get_connection()?; let conn = self.get_connection()?;
reorder_network(&conn, request.into_inner()).map_err(proc_err)?; reorder_network(&conn, request.into_inner()).map_err(proc_err)?;
self.notify_network_update().await?; self.notify_network_update().await?;
self.reconcile_runtime().await?;
Ok(Response::new(Empty {})) Ok(Response::new(Empty {}))
} }
@ -240,6 +269,7 @@ impl Networks for DaemonRPCServer {
let conn = self.get_connection()?; let conn = self.get_connection()?;
delete_network(&conn, request.into_inner()).map_err(proc_err)?; delete_network(&conn, request.into_inner()).map_err(proc_err)?;
self.notify_network_update().await?; self.notify_network_update().await?;
self.reconcile_runtime().await?;
Ok(Response::new(Empty {})) Ok(Response::new(Empty {}))
} }
} }
@ -248,6 +278,20 @@ fn proc_err(err: impl ToString) -> RspStatus {
RspStatus::internal(err.to_string()) RspStatus::internal(err.to_string())
} }
fn configuration_rsp(config: ServerConfig) -> TunnelConfigurationResponse {
TunnelConfigurationResponse {
mtu: config.mtu.unwrap_or(1000),
addresses: config.address,
}
}
fn empty_configuration_rsp() -> TunnelConfigurationResponse {
TunnelConfigurationResponse {
mtu: 1500,
addresses: Vec::new(),
}
}
fn status_rsp(state: RunState) -> TunnelStatusResponse { fn status_rsp(state: RunState) -> TunnelStatusResponse {
TunnelStatusResponse { TunnelStatusResponse {
state: state.to_rpc().into(), state: state.to_rpc().into(),

View file

@ -4,23 +4,20 @@ pub mod apple;
mod instance; mod instance;
mod net; mod net;
pub mod rpc; pub mod rpc;
mod runtime;
use anyhow::{Error as AhError, Result}; use anyhow::{Error as AhError, Result};
use instance::DaemonRPCServer; use instance::DaemonRPCServer;
pub use net::{get_socket_path, DaemonClient}; pub use net::{get_socket_path, DaemonClient};
pub use rpc::{DaemonCommand, DaemonResponseData, DaemonStartOptions}; pub use rpc::{DaemonCommand, DaemonResponseData, DaemonStartOptions};
use tokio::{ use tokio::{net::UnixListener, sync::Notify};
net::UnixListener,
sync::{Notify, RwLock},
};
use tokio_stream::wrappers::UnixListenerStream; use tokio_stream::wrappers::UnixListenerStream;
use tonic::transport::Server; use tonic::transport::Server;
use tracing::{error, info}; use tracing::info;
use crate::{ use crate::{
daemon::rpc::grpc_defs::{networks_server::NetworksServer, tunnel_server::TunnelServer}, daemon::rpc::grpc_defs::{networks_server::NetworksServer, tunnel_server::TunnelServer},
database::{get_connection, load_interface}, database::get_connection,
wireguard::Interface,
}; };
pub async fn daemon_main( pub async fn daemon_main(
@ -28,16 +25,8 @@ pub async fn daemon_main(
db_path: Option<&Path>, db_path: Option<&Path>,
notify_ready: Option<Arc<Notify>>, notify_ready: Option<Arc<Notify>>,
) -> Result<()> { ) -> Result<()> {
if let Some(n) = notify_ready { let _conn = get_connection(db_path)?;
n.notify_one() let burrow_server = DaemonRPCServer::new(db_path)?;
}
let conn = get_connection(db_path)?;
let config = load_interface(&conn, "1")?;
let burrow_server = DaemonRPCServer::new(
Arc::new(RwLock::new(config.clone().try_into()?)),
Arc::new(RwLock::new(config)),
db_path.clone(),
)?;
let spp = socket_path.clone(); let spp = socket_path.clone();
let tmp = get_socket_path(); let tmp = get_socket_path();
let sock_path = spp.unwrap_or(Path::new(tmp.as_str())); let sock_path = spp.unwrap_or(Path::new(tmp.as_str()));
@ -55,9 +44,213 @@ pub async fn daemon_main(
Ok::<(), AhError>(()) Ok::<(), AhError>(())
}); });
if let Some(n) = notify_ready {
n.notify_one();
}
info!("Starting daemon..."); info!("Starting daemon...");
tokio::try_join!(serve_job) tokio::try_join!(serve_job)
.map(|_| ()) .map(|_| ())
.map_err(|e| e.into()) .map_err(|e| e.into())
} }
#[cfg(test)]
mod tests {
use std::{
path::PathBuf,
time::{SystemTime, UNIX_EPOCH},
};
use anyhow::{anyhow, Result};
use iroh::PublicKey;
use serde_json::json;
use tokio::time::{timeout, Duration};
use super::*;
use crate::daemon::rpc::{
client::BurrowClient,
grpc_defs::{
Empty, Network, NetworkListResponse, NetworkReorderRequest, NetworkType,
TunnelConfigurationResponse,
},
};
#[tokio::test]
async fn daemon_tracks_network_priority_via_grpc() -> Result<()> {
let socket_path = temp_path("sock");
let db_path = temp_path("sqlite3");
let ready = Arc::new(Notify::new());
let daemon_ready = ready.clone();
let daemon_socket_path = socket_path.clone();
let daemon_db_path = db_path.clone();
let daemon_task = tokio::spawn(async move {
daemon_main(
Some(daemon_socket_path.as_path()),
Some(daemon_db_path.as_path()),
Some(daemon_ready),
)
.await
});
timeout(Duration::from_secs(5), ready.notified()).await?;
let mut client = timeout(
Duration::from_secs(5),
BurrowClient::from_uds_path(&socket_path),
)
.await??;
let mut config_stream = client
.tunnel_client
.tunnel_configuration(Empty {})
.await?
.into_inner();
let mut network_stream = client
.networks_client
.network_list(Empty {})
.await?
.into_inner();
let initial_config = next_configuration(&mut config_stream).await?;
assert!(initial_config.addresses.is_empty());
assert_eq!(initial_config.mtu, 1500);
let initial_networks = next_networks(&mut network_stream).await?;
assert!(initial_networks.network.is_empty());
let start_err = client
.tunnel_client
.tunnel_start(Empty {})
.await
.expect_err("starting without a stored network should fail");
assert_eq!(start_err.code(), tonic::Code::FailedPrecondition);
client
.networks_client
.network_add(Network {
id: 1,
r#type: NetworkType::WireGuard.into(),
payload: sample_wireguard_payload(),
})
.await?;
let networks_after_wg = next_networks(&mut network_stream).await?;
assert_eq!(
network_ids(&networks_after_wg),
vec![(1, NetworkType::WireGuard)]
);
let wireguard_config = next_configuration(&mut config_stream).await?;
assert_eq!(
wireguard_config.addresses,
vec!["10.8.0.2/32", "fd00::2/128"]
);
assert_eq!(wireguard_config.mtu, 1420);
client
.networks_client
.network_add(Network {
id: 2,
r#type: NetworkType::HackClub.into(),
payload: sample_hackclub_payload(),
})
.await?;
let networks_after_mesh_add = next_networks(&mut network_stream).await?;
assert_eq!(
network_ids(&networks_after_mesh_add),
vec![(1, NetworkType::WireGuard), (2, NetworkType::HackClub)]
);
let still_wireguard = next_configuration(&mut config_stream).await?;
assert_eq!(still_wireguard.addresses, wireguard_config.addresses);
client
.networks_client
.network_reorder(NetworkReorderRequest { id: 2, index: 0 })
.await?;
let networks_after_reorder = next_networks(&mut network_stream).await?;
assert_eq!(
network_ids(&networks_after_reorder),
vec![(2, NetworkType::HackClub), (1, NetworkType::WireGuard)]
);
let mesh_config = next_configuration(&mut config_stream).await?;
assert_eq!(mesh_config.addresses, vec!["10.77.0.2/32"]);
assert_eq!(mesh_config.mtu, 1380);
daemon_task.abort();
let _ = daemon_task.await;
cleanup_path(&socket_path);
cleanup_path(&db_path);
Ok(())
}
fn temp_path(ext: &str) -> PathBuf {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("system time is after unix epoch")
.as_nanos();
std::env::temp_dir().join(format!("burrow-daemon-test-{now}.{ext}"))
}
fn cleanup_path(path: &Path) {
let _ = std::fs::remove_file(path);
}
fn sample_wireguard_payload() -> Vec<u8> {
br#"[Interface]
PrivateKey = OEPVdomeLTxTIBvv3TYsJRge0Hp9NMiY0sIrhT8OWG8=
Address = 10.8.0.2/32, fd00::2/128
ListenPort = 51820
MTU = 1420
[Peer]
PublicKey = 8GaFjVO6c4luCHG4ONO+1bFG8tO+Zz5/Gy+Geht1USM=
PresharedKey = ha7j4BjD49sIzyF9SNlbueK0AMHghlj6+u0G3bzC698=
AllowedIPs = 0.0.0.0/0, ::/0
Endpoint = wg.burrow.rs:51820
"#
.to_vec()
}
fn sample_hackclub_payload() -> Vec<u8> {
let endpoint_id = PublicKey::from_bytes(&[0; 32]).unwrap().to_string();
json!({
"endpoint_id": endpoint_id,
"addresses": ["127.0.0.1:7777"],
"local_addresses": ["10.77.0.2/32"],
"mtu": 1380,
"tun_name": "burrow-test-mesh",
})
.to_string()
.into_bytes()
}
async fn next_configuration(
stream: &mut tonic::Streaming<TunnelConfigurationResponse>,
) -> Result<TunnelConfigurationResponse> {
timeout(Duration::from_secs(5), stream.message())
.await??
.ok_or_else(|| anyhow!("configuration stream ended unexpectedly"))
}
async fn next_networks(
stream: &mut tonic::Streaming<NetworkListResponse>,
) -> Result<NetworkListResponse> {
timeout(Duration::from_secs(5), stream.message())
.await??
.ok_or_else(|| anyhow!("network stream ended unexpectedly"))
}
fn network_ids(response: &NetworkListResponse) -> Vec<(i32, NetworkType)> {
response
.network
.iter()
.map(|network| (network.id, network.r#type()))
.collect()
}
}

View file

@ -0,0 +1,180 @@
use std::sync::Arc;
use anyhow::{Context, Result};
use tokio::{sync::RwLock, task::JoinHandle};
use tun::{tokio::TunInterface, TunOptions};
use super::rpc::{
grpc_defs::{Network, NetworkType},
ServerConfig,
};
use crate::{
mesh::iroh::{self as mesh_iroh, HackClubNetworkConfig, MeshHandle},
wireguard::{Config, Interface as WireGuardInterface},
};
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum RuntimeIdentity {
Network {
id: i32,
network_type: NetworkType,
payload: Vec<u8>,
},
}
#[derive(Clone, Debug)]
pub enum ResolvedTunnel {
WireGuard {
identity: RuntimeIdentity,
config: Config,
},
HackClub {
identity: RuntimeIdentity,
config: HackClubNetworkConfig,
},
}
impl ResolvedTunnel {
pub fn from_networks(networks: &[Network]) -> Result<Option<Self>> {
let Some(network) = networks.first() else {
return Ok(None);
};
let identity = RuntimeIdentity::Network {
id: network.id,
network_type: network.r#type(),
payload: network.payload.clone(),
};
match network.r#type() {
NetworkType::WireGuard => {
let payload = String::from_utf8(network.payload.clone())
.context("wireguard payload must be valid UTF-8")?;
let config = Config::from_content_fmt(&payload, "ini")?;
Ok(Some(Self::WireGuard { identity, config }))
}
NetworkType::HackClub => {
let config = HackClubNetworkConfig::from_payload(&network.payload)?;
Ok(Some(Self::HackClub { identity, config }))
}
}
}
pub fn identity(&self) -> &RuntimeIdentity {
match self {
Self::WireGuard { identity, .. } | Self::HackClub { identity, .. } => identity,
}
}
pub fn server_config(&self) -> Result<ServerConfig> {
match self {
Self::WireGuard { config, .. } => ServerConfig::try_from(config),
Self::HackClub { config, .. } => Ok(ServerConfig {
address: config.local_addresses.clone(),
name: config.tun_name.clone(),
mtu: config.mtu.map(i32::from),
}),
}
}
pub async fn start(
self,
tun_interface: Arc<RwLock<Option<TunInterface>>>,
) -> Result<ActiveTunnel> {
match self {
Self::WireGuard { identity, config } => {
let tun = TunOptions::new().open()?;
tun_interface.write().await.replace(tun);
match start_wireguard_runtime(config, tun_interface.clone()).await {
Ok((interface, task)) => {
Ok(ActiveTunnel::WireGuard { identity, interface, task })
}
Err(err) => {
tun_interface.write().await.take();
Err(err)
}
}
}
Self::HackClub { identity, config } => {
let mut tun_opts = TunOptions::new();
if let Some(name) = config.tun_name.as_deref() {
tun_opts = tun_opts.name(name);
}
let tun = tun_opts.open()?;
tun_interface.write().await.replace(tun);
match mesh_iroh::spawn_hackclub_tunnel(config, tun_interface.clone()).await {
Ok(handle) => Ok(ActiveTunnel::HackClub { identity, handle }),
Err(err) => {
tun_interface.write().await.take();
Err(err)
}
}
}
}
}
}
pub enum ActiveTunnel {
WireGuard {
identity: RuntimeIdentity,
interface: Arc<RwLock<WireGuardInterface>>,
task: JoinHandle<Result<()>>,
},
HackClub {
identity: RuntimeIdentity,
handle: MeshHandle,
},
}
impl ActiveTunnel {
pub fn identity(&self) -> &RuntimeIdentity {
match self {
Self::WireGuard { identity, .. } | Self::HackClub { identity, .. } => identity,
}
}
pub async fn shutdown(self, tun_interface: &Arc<RwLock<Option<TunInterface>>>) -> Result<()> {
match self {
Self::WireGuard { interface, task, .. } => {
interface.read().await.remove_tun().await;
let task_result = task.await;
tun_interface.write().await.take();
task_result??;
Ok(())
}
Self::HackClub { handle, .. } => {
let result = handle.shutdown().await;
tun_interface.write().await.take();
result
}
}
}
}
async fn start_wireguard_runtime(
config: Config,
tun_interface: Arc<RwLock<Option<TunInterface>>>,
) -> Result<(Arc<RwLock<WireGuardInterface>>, JoinHandle<Result<()>>)> {
let mut interface: WireGuardInterface = config.try_into()?;
interface.set_tun_ref(tun_interface).await;
let interface = Arc::new(RwLock::new(interface));
let run_interface = interface.clone();
let task = tokio::spawn(async move {
let guard = run_interface.read().await;
guard.run().await
});
Ok((interface, task))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn no_networks_resolves_to_no_tunnel() {
assert!(ResolvedTunnel::from_networks(&[]).unwrap().is_none());
}
}

View file

@ -5,11 +5,9 @@ use rusqlite::{params, Connection};
use crate::{ use crate::{
daemon::rpc::grpc_defs::{ daemon::rpc::grpc_defs::{
Network as RPCNetwork, Network as RPCNetwork, NetworkDeleteRequest, NetworkReorderRequest, NetworkType,
NetworkDeleteRequest,
NetworkReorderRequest,
NetworkType,
}, },
mesh::iroh::HackClubNetworkConfig,
wireguard::config::{Config, Interface, Peer}, wireguard::config::{Config, Interface, Peer},
}; };
@ -124,35 +122,26 @@ pub fn dump_interface(conn: &Connection, config: &Config) -> Result<()> {
pub fn get_connection(path: Option<&Path>) -> Result<Connection> { pub fn get_connection(path: Option<&Path>) -> Result<Connection> {
let p = path.unwrap_or_else(|| std::path::Path::new(DB_PATH)); let p = path.unwrap_or_else(|| std::path::Path::new(DB_PATH));
if !p.exists() {
let conn = Connection::open(p)?; let conn = Connection::open(p)?;
initialize_tables(&conn)?; initialize_tables(&conn)?;
dump_interface(&conn, &Config::default())?; Ok(conn)
return Ok(conn);
}
Ok(Connection::open(p)?)
} }
pub fn add_network(conn: &Connection, network: &RPCNetwork) -> Result<()> { pub fn add_network(conn: &Connection, network: &RPCNetwork) -> Result<()> {
validate_network_payload(network)?;
let mut stmt = conn.prepare("INSERT INTO network (id, type, payload) VALUES (?, ?, ?)")?; let mut stmt = conn.prepare("INSERT INTO network (id, type, payload) VALUES (?, ?, ?)")?;
stmt.execute(params![ stmt.execute(params![
network.id, network.id,
network.r#type().as_str_name(), network.r#type().as_str_name(),
&network.payload &network.payload
])?; ])?;
if network.r#type() == NetworkType::WireGuard {
let payload_str = String::from_utf8(network.payload.clone())?;
let wg_config = Config::from_content_fmt(&payload_str, "ini")?;
dump_interface(conn, &wg_config)?;
}
Ok(()) Ok(())
} }
pub fn list_networks(conn: &Connection) -> Result<Vec<RPCNetwork>> { pub fn list_networks(conn: &Connection) -> Result<Vec<RPCNetwork>> {
let mut stmt = conn.prepare("SELECT id, type, payload FROM network ORDER BY idx")?; let mut stmt = conn.prepare("SELECT id, type, payload FROM network ORDER BY idx, id")?;
let networks: Vec<RPCNetwork> = stmt let networks: Vec<RPCNetwork> = stmt
.query_map([], |row| { .query_map([], |row| {
println!("row: {:?}", row);
let network_id: i32 = row.get(0)?; let network_id: i32 = row.get(0)?;
let network_type: String = row.get(1)?; let network_type: String = row.get(1)?;
let network_type = NetworkType::from_str_name(network_type.as_str()) let network_type = NetworkType::from_str_name(network_type.as_str())
@ -169,12 +158,19 @@ pub fn list_networks(conn: &Connection) -> Result<Vec<RPCNetwork>> {
} }
pub fn reorder_network(conn: &Connection, req: NetworkReorderRequest) -> Result<()> { pub fn reorder_network(conn: &Connection, req: NetworkReorderRequest) -> Result<()> {
let mut stmt = conn.prepare("UPDATE network SET idx = ? WHERE id = ?")?; let mut ordered_ids = ordered_network_ids(conn)?;
let res = stmt.execute(params![req.index, req.id])?; let Some(current_idx) = ordered_ids.iter().position(|id| *id == req.id) else {
if res == 0 {
return Err(anyhow::anyhow!("No such network exists")); return Err(anyhow::anyhow!("No such network exists"));
} };
Ok(())
let target_idx = usize::try_from(req.index)
.map_err(|_| anyhow::anyhow!("Network index must be non-negative"))?;
let moved_id = ordered_ids.remove(current_idx);
let target_idx = target_idx.min(ordered_ids.len());
ordered_ids.insert(target_idx, moved_id);
renumber_networks(conn, &ordered_ids)
} }
pub fn delete_network(conn: &Connection, req: NetworkDeleteRequest) -> Result<()> { pub fn delete_network(conn: &Connection, req: NetworkDeleteRequest) -> Result<()> {
@ -183,7 +179,8 @@ pub fn delete_network(conn: &Connection, req: NetworkDeleteRequest) -> Result<()
if res == 0 { if res == 0 {
return Err(anyhow::anyhow!("No such network exists")); return Err(anyhow::anyhow!("No such network exists"));
} }
Ok(()) let ordered_ids = ordered_network_ids(conn)?;
renumber_networks(conn, &ordered_ids)
} }
fn parse_lst(s: &str) -> Vec<String> { fn parse_lst(s: &str) -> Vec<String> {
@ -200,9 +197,83 @@ fn to_lst<T: ToString>(v: &Vec<T>) -> String {
.join(",") .join(",")
} }
fn validate_network_payload(network: &RPCNetwork) -> Result<()> {
match network.r#type() {
NetworkType::WireGuard => {
let payload_str = String::from_utf8(network.payload.clone())?;
Config::from_content_fmt(&payload_str, "ini")?;
}
NetworkType::HackClub => {
HackClubNetworkConfig::from_payload(&network.payload)?;
}
}
Ok(())
}
fn ordered_network_ids(conn: &Connection) -> Result<Vec<i32>> {
let mut stmt = conn.prepare("SELECT id FROM network ORDER BY idx, id")?;
let ids = stmt
.query_map([], |row| row.get::<_, i32>(0))?
.collect::<rusqlite::Result<Vec<i32>>>()?;
Ok(ids)
}
fn renumber_networks(conn: &Connection, ordered_ids: &[i32]) -> Result<()> {
conn.execute_batch("BEGIN IMMEDIATE")?;
let result = (|| -> Result<()> {
let mut stmt = conn.prepare("UPDATE network SET idx = ? WHERE id = ?")?;
for (idx, id) in ordered_ids.iter().enumerate() {
stmt.execute(params![idx as i32, id])?;
}
Ok(())
})();
match result {
Ok(()) => {
conn.execute_batch("COMMIT")?;
Ok(())
}
Err(err) => {
let _ = conn.execute_batch("ROLLBACK");
Err(err)
}
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use iroh::PublicKey;
use serde_json::json;
use tempfile::tempdir;
fn sample_wireguard_payload() -> Vec<u8> {
br#"[Interface]
PrivateKey = OEPVdomeLTxTIBvv3TYsJRge0Hp9NMiY0sIrhT8OWG8=
Address = 10.13.13.2/24
ListenPort = 51820
[Peer]
PublicKey = 8GaFjVO6c4luCHG4ONO+1bFG8tO+Zz5/Gy+Geht1USM=
PresharedKey = ha7j4BjD49sIzyF9SNlbueK0AMHghlj6+u0G3bzC698=
AllowedIPs = 0.0.0.0/0, 8.8.8.8/32
Endpoint = wg.burrow.rs:51820
"#
.to_vec()
}
fn sample_hackclub_payload(name: &str, address: &str) -> Vec<u8> {
let endpoint_id = PublicKey::from_bytes(&[0; 32]).unwrap().to_string();
json!({
"endpoint_id": endpoint_id,
"addresses": ["127.0.0.1:7777"],
"local_addresses": [address],
"mtu": 1380,
"tun_name": name,
})
.to_string()
.into_bytes()
}
#[test] #[test]
fn test_db() { fn test_db() {
@ -213,4 +284,103 @@ mod tests {
let loaded = load_interface(&conn, "1").unwrap(); let loaded = load_interface(&conn, "1").unwrap();
assert_eq!(config, loaded); assert_eq!(config, loaded);
} }
#[test]
fn add_network_validates_payloads() {
let conn = Connection::open_in_memory().unwrap();
initialize_tables(&conn).unwrap();
add_network(
&conn,
&RPCNetwork {
id: 1,
r#type: NetworkType::WireGuard.into(),
payload: sample_wireguard_payload(),
},
)
.unwrap();
add_network(
&conn,
&RPCNetwork {
id: 2,
r#type: NetworkType::HackClub.into(),
payload: sample_hackclub_payload("burrow-test-0", "10.42.0.2/32"),
},
)
.unwrap();
assert!(add_network(
&conn,
&RPCNetwork {
id: 3,
r#type: NetworkType::WireGuard.into(),
payload: b"not-a-config".to_vec(),
},
)
.is_err());
let ids: Vec<i32> = list_networks(&conn)
.unwrap()
.into_iter()
.map(|n| n.id)
.collect();
assert_eq!(ids, vec![1, 2]);
}
#[test]
fn reorder_and_delete_networks_keep_priority_stable() {
let conn = Connection::open_in_memory().unwrap();
initialize_tables(&conn).unwrap();
for (id, name, address) in [
(1, "burrow-test-1", "10.42.0.2/32"),
(2, "burrow-test-2", "10.42.0.3/32"),
(3, "burrow-test-3", "10.42.0.4/32"),
] {
add_network(
&conn,
&RPCNetwork {
id,
r#type: NetworkType::HackClub.into(),
payload: sample_hackclub_payload(name, address),
},
)
.unwrap();
}
reorder_network(&conn, NetworkReorderRequest { id: 3, index: 0 }).unwrap();
let ids: Vec<i32> = list_networks(&conn)
.unwrap()
.into_iter()
.map(|n| n.id)
.collect();
assert_eq!(ids, vec![3, 1, 2]);
delete_network(&conn, NetworkDeleteRequest { id: 1 }).unwrap();
let ids: Vec<i32> = list_networks(&conn)
.unwrap()
.into_iter()
.map(|n| n.id)
.collect();
assert_eq!(ids, vec![3, 2]);
}
#[test]
fn get_connection_does_not_seed_a_default_interface() {
let dir = tempdir().unwrap();
let db_path = dir.path().join("burrow.sqlite3");
let conn = get_connection(Some(db_path.as_path())).unwrap();
let interface_count: i64 = conn
.query_row("SELECT COUNT(*) FROM wg_interface", [], |row| row.get(0))
.unwrap();
let network_count: i64 = conn
.query_row("SELECT COUNT(*) FROM network", [], |row| row.get(0))
.unwrap();
assert_eq!(interface_count, 0);
assert_eq!(network_count, 0);
}
} }