diff --git a/burrow/src/daemon/instance.rs b/burrow/src/daemon/instance.rs index ce96fa5..fdcd95f 100644 --- a/burrow/src/daemon/instance.rs +++ b/burrow/src/daemon/instance.rs @@ -1,48 +1,27 @@ use std::{ - ops::Deref, path::{Path, PathBuf}, sync::Arc, - time::Duration, }; use anyhow::Result; use rusqlite::Connection; -use tokio::sync::{mpsc, watch, Notify, RwLock}; +use tokio::sync::{mpsc, watch, RwLock}; use tokio_stream::wrappers::ReceiverStream; use tonic::{Request, Response, Status as RspStatus}; -use tracing::{debug, info, warn}; -use tun::{tokio::TunInterface, TunOptions}; +use tracing::warn; +use tun::tokio::TunInterface; -use super::rpc::grpc_defs::{ - networks_server::Networks, - tunnel_server::Tunnel, - Empty, - Network, - NetworkDeleteRequest, - NetworkListResponse, - NetworkReorderRequest, - State as RPCTunnelState, - TunnelConfigurationResponse, - TunnelStatusResponse, +use super::{ + rpc::grpc_defs::{ + networks_server::Networks, tunnel_server::Tunnel, Empty, Network, NetworkDeleteRequest, + NetworkListResponse, NetworkReorderRequest, State as RPCTunnelState, + TunnelConfigurationResponse, TunnelStatusResponse, + }, + runtime::{ActiveTunnel, ResolvedTunnel}, }; use crate::{ - daemon::rpc::{ - DaemonCommand, - DaemonNotification, - DaemonResponse, - DaemonResponseData, - ServerConfig, - ServerInfo, - }, - database::{ - add_network, - delete_network, - get_connection, - list_networks, - load_interface, - reorder_network, - }, - wireguard::{Config, Interface}, + daemon::rpc::ServerConfig, + database::{add_network, delete_network, get_connection, list_networks, reorder_network}, }; #[derive(Debug, Clone)] @@ -52,10 +31,10 @@ enum RunState { } impl RunState { - pub fn to_rpc(&self) -> RPCTunnelState { + fn to_rpc(&self) -> RPCTunnelState { match self { - RunState::Running => RPCTunnelState::Running, - RunState::Idle => RPCTunnelState::Stopped, + Self::Running => RPCTunnelState::Running, + Self::Idle => RPCTunnelState::Stopped, } } } @@ -63,30 +42,24 @@ impl RunState { #[derive(Clone)] pub struct DaemonRPCServer { tun_interface: Arc>>, - wg_interface: Arc>, - config: Arc>, db_path: Option, wg_state_chan: (watch::Sender, watch::Receiver), network_update_chan: (watch::Sender<()>, watch::Receiver<()>), + active_tunnel: Arc>>, } impl DaemonRPCServer { - pub fn new( - wg_interface: Arc>, - config: Arc>, - db_path: Option<&Path>, - ) -> Result { + pub fn new(db_path: Option<&Path>) -> Result { Ok(Self { tun_interface: Arc::new(RwLock::new(None)), - wg_interface, - config, - db_path: db_path.map(|p| p.to_owned()), + db_path: db_path.map(Path::to_owned), wg_state_chan: watch::channel(RunState::Idle), network_update_chan: watch::channel(()), + active_tunnel: Arc::new(RwLock::new(None)), }) } - pub fn get_connection(&self) -> Result { + fn get_connection(&self) -> Result { get_connection(self.db_path.as_deref()).map_err(proc_err) } @@ -94,13 +67,71 @@ impl DaemonRPCServer { self.wg_state_chan.0.send(state).map_err(proc_err) } - async fn get_wg_state(&self) -> RunState { - self.wg_state_chan.1.borrow().to_owned() - } - async fn notify_network_update(&self) -> Result<(), RspStatus> { self.network_update_chan.0.send(()).map_err(proc_err) } + + async fn resolve_tunnel(&self) -> Result, RspStatus> { + let conn = self.get_connection()?; + let networks = list_networks(&conn).map_err(proc_err)?; + ResolvedTunnel::from_networks(&networks).map_err(proc_err) + } + + async fn current_tunnel_configuration(&self) -> Result { + match self.resolve_tunnel().await? { + Some(config) => { + let config = config.server_config().map_err(proc_err)?; + Ok(configuration_rsp(config)) + } + None => Ok(empty_configuration_rsp()), + } + } + + async fn stop_active_tunnel(&self) -> Result { + let current = { self.active_tunnel.write().await.take() }; + let Some(current) = current else { + return Ok(false); + }; + + current + .shutdown(&self.tun_interface) + .await + .map_err(proc_err)?; + self.set_wg_state(RunState::Idle).await?; + Ok(true) + } + + async fn replace_active_tunnel(&self, desired: ResolvedTunnel) -> Result<(), RspStatus> { + let _ = self.stop_active_tunnel().await?; + let active = desired + .start(self.tun_interface.clone()) + .await + .map_err(proc_err)?; + self.active_tunnel.write().await.replace(active); + self.set_wg_state(RunState::Running).await?; + Ok(()) + } + + async fn reconcile_runtime(&self) -> Result<(), RspStatus> { + let desired = self.resolve_tunnel().await?; + let Some(desired) = desired else { + let _ = self.stop_active_tunnel().await?; + return Ok(()); + }; + let needs_restart = { + let guard = self.active_tunnel.read().await; + guard + .as_ref() + .map(|active| active.identity() != desired.identity()) + .unwrap_or(false) + }; + + if needs_restart { + self.replace_active_tunnel(desired).await?; + } + + Ok(()) + } } #[tonic::async_trait] @@ -113,55 +144,49 @@ impl Tunnel for DaemonRPCServer { _request: Request, ) -> Result, RspStatus> { let (tx, rx) = mpsc::channel(10); + let server = self.clone(); + let mut sub = self.network_update_chan.1.clone(); + tokio::spawn(async move { - let serv_config = ServerConfig::default(); - tx.send(Ok(TunnelConfigurationResponse { - mtu: serv_config.mtu.unwrap_or(1000), - addresses: serv_config.address, - })) - .await + loop { + let response = server.current_tunnel_configuration().await; + if tx.send(response).await.is_err() { + break; + } + if sub.changed().await.is_err() { + break; + } + } }); + Ok(Response::new(ReceiverStream::new(rx))) } async fn tunnel_start(&self, _request: Request) -> Result, RspStatus> { - let wg_state = self.get_wg_state().await; - match wg_state { - RunState::Idle => { - let tun_if = TunOptions::new().open()?; - debug!("Setting tun on wg_interface"); - self.tun_interface.write().await.replace(tun_if); - self.wg_interface - .write() - .await - .set_tun_ref(self.tun_interface.clone()) - .await; - debug!("tun set on wg_interface"); + let desired = self + .resolve_tunnel() + .await? + .ok_or_else(|| RspStatus::failed_precondition("no stored network configured"))?; + let already_running = { + let guard = self.active_tunnel.read().await; + guard + .as_ref() + .map(|active| active.identity() == desired.identity()) + .unwrap_or(false) + }; - debug!("Setting tun_interface"); - debug!("tun_interface set: {:?}", self.tun_interface); - - debug!("Cloning wg_interface"); - let tmp_wg = self.wg_interface.clone(); - let run_task = tokio::spawn(async move { - let twlock = tmp_wg.read().await; - twlock.run().await - }); - self.set_wg_state(RunState::Running).await?; - } - - RunState::Running => { - warn!("Got start, but tun interface already up."); - } + if already_running { + warn!("Got start, but active tunnel already matches desired network."); + return Ok(Response::new(Empty {})); } - return Ok(Response::new(Empty {})); + self.replace_active_tunnel(desired).await?; + Ok(Response::new(Empty {})) } async fn tunnel_stop(&self, _request: Request) -> Result, RspStatus> { - self.wg_interface.write().await.remove_tun().await; - self.set_wg_state(RunState::Idle).await?; - return Ok(Response::new(Empty {})); + let _ = self.stop_active_tunnel().await?; + Ok(Response::new(Empty {})) } async fn tunnel_status( @@ -172,13 +197,16 @@ impl Tunnel for DaemonRPCServer { let mut state_rx = self.wg_state_chan.1.clone(); tokio::spawn(async move { let cur = state_rx.borrow_and_update().to_owned(); - tx.send(Ok(status_rsp(cur))).await; + if tx.send(Ok(status_rsp(cur))).await.is_err() { + return; + } + loop { - state_rx.changed().await.unwrap(); + if state_rx.changed().await.is_err() { + break; + } let cur = state_rx.borrow().to_owned(); - let res = tx.send(Ok(status_rsp(cur))).await; - if res.is_err() { - eprintln!("Tunnel status channel closed"); + if tx.send(Ok(status_rsp(cur))).await.is_err() { break; } } @@ -196,6 +224,7 @@ impl Networks for DaemonRPCServer { let network = request.into_inner(); add_network(&conn, &network).map_err(proc_err)?; self.notify_network_update().await?; + self.reconcile_runtime().await?; Ok(Response::new(Empty {})) } @@ -203,7 +232,6 @@ impl Networks for DaemonRPCServer { &self, _request: Request, ) -> Result, RspStatus> { - debug!("Mock network_list called"); let (tx, rx) = mpsc::channel(10); let conn = self.get_connection()?; let mut sub = self.network_update_chan.1.clone(); @@ -212,12 +240,12 @@ impl Networks for DaemonRPCServer { let networks = list_networks(&conn) .map(|res| NetworkListResponse { network: res }) .map_err(proc_err); - let res = tx.send(networks).await; - if res.is_err() { - eprintln!("Network list channel closed"); + if tx.send(networks).await.is_err() { + break; + } + if sub.changed().await.is_err() { break; } - sub.changed().await.unwrap(); } }); Ok(Response::new(ReceiverStream::new(rx))) @@ -230,6 +258,7 @@ impl Networks for DaemonRPCServer { let conn = self.get_connection()?; reorder_network(&conn, request.into_inner()).map_err(proc_err)?; self.notify_network_update().await?; + self.reconcile_runtime().await?; Ok(Response::new(Empty {})) } @@ -240,6 +269,7 @@ impl Networks for DaemonRPCServer { let conn = self.get_connection()?; delete_network(&conn, request.into_inner()).map_err(proc_err)?; self.notify_network_update().await?; + self.reconcile_runtime().await?; Ok(Response::new(Empty {})) } } @@ -248,6 +278,20 @@ fn proc_err(err: impl ToString) -> RspStatus { RspStatus::internal(err.to_string()) } +fn configuration_rsp(config: ServerConfig) -> TunnelConfigurationResponse { + TunnelConfigurationResponse { + mtu: config.mtu.unwrap_or(1000), + addresses: config.address, + } +} + +fn empty_configuration_rsp() -> TunnelConfigurationResponse { + TunnelConfigurationResponse { + mtu: 1500, + addresses: Vec::new(), + } +} + fn status_rsp(state: RunState) -> TunnelStatusResponse { TunnelStatusResponse { state: state.to_rpc().into(), diff --git a/burrow/src/daemon/mod.rs b/burrow/src/daemon/mod.rs index f6b973f..f5ad7d3 100644 --- a/burrow/src/daemon/mod.rs +++ b/burrow/src/daemon/mod.rs @@ -4,23 +4,20 @@ pub mod apple; mod instance; mod net; pub mod rpc; +mod runtime; use anyhow::{Error as AhError, Result}; use instance::DaemonRPCServer; pub use net::{get_socket_path, DaemonClient}; pub use rpc::{DaemonCommand, DaemonResponseData, DaemonStartOptions}; -use tokio::{ - net::UnixListener, - sync::{Notify, RwLock}, -}; +use tokio::{net::UnixListener, sync::Notify}; use tokio_stream::wrappers::UnixListenerStream; use tonic::transport::Server; -use tracing::{error, info}; +use tracing::info; use crate::{ daemon::rpc::grpc_defs::{networks_server::NetworksServer, tunnel_server::TunnelServer}, - database::{get_connection, load_interface}, - wireguard::Interface, + database::get_connection, }; pub async fn daemon_main( @@ -28,16 +25,8 @@ pub async fn daemon_main( db_path: Option<&Path>, notify_ready: Option>, ) -> Result<()> { - if let Some(n) = notify_ready { - n.notify_one() - } - let conn = get_connection(db_path)?; - let config = load_interface(&conn, "1")?; - let burrow_server = DaemonRPCServer::new( - Arc::new(RwLock::new(config.clone().try_into()?)), - Arc::new(RwLock::new(config)), - db_path.clone(), - )?; + let _conn = get_connection(db_path)?; + let burrow_server = DaemonRPCServer::new(db_path)?; let spp = socket_path.clone(); let tmp = get_socket_path(); let sock_path = spp.unwrap_or(Path::new(tmp.as_str())); @@ -55,9 +44,213 @@ pub async fn daemon_main( Ok::<(), AhError>(()) }); + if let Some(n) = notify_ready { + n.notify_one(); + } + info!("Starting daemon..."); tokio::try_join!(serve_job) .map(|_| ()) .map_err(|e| e.into()) } + +#[cfg(test)] +mod tests { + use std::{ + path::PathBuf, + time::{SystemTime, UNIX_EPOCH}, + }; + + use anyhow::{anyhow, Result}; + use iroh::PublicKey; + use serde_json::json; + use tokio::time::{timeout, Duration}; + + use super::*; + use crate::daemon::rpc::{ + client::BurrowClient, + grpc_defs::{ + Empty, Network, NetworkListResponse, NetworkReorderRequest, NetworkType, + TunnelConfigurationResponse, + }, + }; + + #[tokio::test] + async fn daemon_tracks_network_priority_via_grpc() -> Result<()> { + let socket_path = temp_path("sock"); + let db_path = temp_path("sqlite3"); + let ready = Arc::new(Notify::new()); + + let daemon_ready = ready.clone(); + let daemon_socket_path = socket_path.clone(); + let daemon_db_path = db_path.clone(); + let daemon_task = tokio::spawn(async move { + daemon_main( + Some(daemon_socket_path.as_path()), + Some(daemon_db_path.as_path()), + Some(daemon_ready), + ) + .await + }); + + timeout(Duration::from_secs(5), ready.notified()).await?; + + let mut client = timeout( + Duration::from_secs(5), + BurrowClient::from_uds_path(&socket_path), + ) + .await??; + let mut config_stream = client + .tunnel_client + .tunnel_configuration(Empty {}) + .await? + .into_inner(); + let mut network_stream = client + .networks_client + .network_list(Empty {}) + .await? + .into_inner(); + + let initial_config = next_configuration(&mut config_stream).await?; + assert!(initial_config.addresses.is_empty()); + assert_eq!(initial_config.mtu, 1500); + + let initial_networks = next_networks(&mut network_stream).await?; + assert!(initial_networks.network.is_empty()); + + let start_err = client + .tunnel_client + .tunnel_start(Empty {}) + .await + .expect_err("starting without a stored network should fail"); + assert_eq!(start_err.code(), tonic::Code::FailedPrecondition); + + client + .networks_client + .network_add(Network { + id: 1, + r#type: NetworkType::WireGuard.into(), + payload: sample_wireguard_payload(), + }) + .await?; + + let networks_after_wg = next_networks(&mut network_stream).await?; + assert_eq!( + network_ids(&networks_after_wg), + vec![(1, NetworkType::WireGuard)] + ); + + let wireguard_config = next_configuration(&mut config_stream).await?; + assert_eq!( + wireguard_config.addresses, + vec!["10.8.0.2/32", "fd00::2/128"] + ); + assert_eq!(wireguard_config.mtu, 1420); + + client + .networks_client + .network_add(Network { + id: 2, + r#type: NetworkType::HackClub.into(), + payload: sample_hackclub_payload(), + }) + .await?; + + let networks_after_mesh_add = next_networks(&mut network_stream).await?; + assert_eq!( + network_ids(&networks_after_mesh_add), + vec![(1, NetworkType::WireGuard), (2, NetworkType::HackClub)] + ); + + let still_wireguard = next_configuration(&mut config_stream).await?; + assert_eq!(still_wireguard.addresses, wireguard_config.addresses); + + client + .networks_client + .network_reorder(NetworkReorderRequest { id: 2, index: 0 }) + .await?; + + let networks_after_reorder = next_networks(&mut network_stream).await?; + assert_eq!( + network_ids(&networks_after_reorder), + vec![(2, NetworkType::HackClub), (1, NetworkType::WireGuard)] + ); + + let mesh_config = next_configuration(&mut config_stream).await?; + assert_eq!(mesh_config.addresses, vec!["10.77.0.2/32"]); + assert_eq!(mesh_config.mtu, 1380); + + daemon_task.abort(); + let _ = daemon_task.await; + cleanup_path(&socket_path); + cleanup_path(&db_path); + + Ok(()) + } + + fn temp_path(ext: &str) -> PathBuf { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("system time is after unix epoch") + .as_nanos(); + std::env::temp_dir().join(format!("burrow-daemon-test-{now}.{ext}")) + } + + fn cleanup_path(path: &Path) { + let _ = std::fs::remove_file(path); + } + + fn sample_wireguard_payload() -> Vec { + br#"[Interface] +PrivateKey = OEPVdomeLTxTIBvv3TYsJRge0Hp9NMiY0sIrhT8OWG8= +Address = 10.8.0.2/32, fd00::2/128 +ListenPort = 51820 +MTU = 1420 + +[Peer] +PublicKey = 8GaFjVO6c4luCHG4ONO+1bFG8tO+Zz5/Gy+Geht1USM= +PresharedKey = ha7j4BjD49sIzyF9SNlbueK0AMHghlj6+u0G3bzC698= +AllowedIPs = 0.0.0.0/0, ::/0 +Endpoint = wg.burrow.rs:51820 +"# + .to_vec() + } + + fn sample_hackclub_payload() -> Vec { + let endpoint_id = PublicKey::from_bytes(&[0; 32]).unwrap().to_string(); + json!({ + "endpoint_id": endpoint_id, + "addresses": ["127.0.0.1:7777"], + "local_addresses": ["10.77.0.2/32"], + "mtu": 1380, + "tun_name": "burrow-test-mesh", + }) + .to_string() + .into_bytes() + } + + async fn next_configuration( + stream: &mut tonic::Streaming, + ) -> Result { + timeout(Duration::from_secs(5), stream.message()) + .await?? + .ok_or_else(|| anyhow!("configuration stream ended unexpectedly")) + } + + async fn next_networks( + stream: &mut tonic::Streaming, + ) -> Result { + timeout(Duration::from_secs(5), stream.message()) + .await?? + .ok_or_else(|| anyhow!("network stream ended unexpectedly")) + } + + fn network_ids(response: &NetworkListResponse) -> Vec<(i32, NetworkType)> { + response + .network + .iter() + .map(|network| (network.id, network.r#type())) + .collect() + } +} diff --git a/burrow/src/daemon/runtime.rs b/burrow/src/daemon/runtime.rs new file mode 100644 index 0000000..31c0b0a --- /dev/null +++ b/burrow/src/daemon/runtime.rs @@ -0,0 +1,180 @@ +use std::sync::Arc; + +use anyhow::{Context, Result}; +use tokio::{sync::RwLock, task::JoinHandle}; +use tun::{tokio::TunInterface, TunOptions}; + +use super::rpc::{ + grpc_defs::{Network, NetworkType}, + ServerConfig, +}; +use crate::{ + mesh::iroh::{self as mesh_iroh, HackClubNetworkConfig, MeshHandle}, + wireguard::{Config, Interface as WireGuardInterface}, +}; + +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum RuntimeIdentity { + Network { + id: i32, + network_type: NetworkType, + payload: Vec, + }, +} + +#[derive(Clone, Debug)] +pub enum ResolvedTunnel { + WireGuard { + identity: RuntimeIdentity, + config: Config, + }, + HackClub { + identity: RuntimeIdentity, + config: HackClubNetworkConfig, + }, +} + +impl ResolvedTunnel { + pub fn from_networks(networks: &[Network]) -> Result> { + let Some(network) = networks.first() else { + return Ok(None); + }; + + let identity = RuntimeIdentity::Network { + id: network.id, + network_type: network.r#type(), + payload: network.payload.clone(), + }; + + match network.r#type() { + NetworkType::WireGuard => { + let payload = String::from_utf8(network.payload.clone()) + .context("wireguard payload must be valid UTF-8")?; + let config = Config::from_content_fmt(&payload, "ini")?; + Ok(Some(Self::WireGuard { identity, config })) + } + NetworkType::HackClub => { + let config = HackClubNetworkConfig::from_payload(&network.payload)?; + Ok(Some(Self::HackClub { identity, config })) + } + } + } + + pub fn identity(&self) -> &RuntimeIdentity { + match self { + Self::WireGuard { identity, .. } | Self::HackClub { identity, .. } => identity, + } + } + + pub fn server_config(&self) -> Result { + match self { + Self::WireGuard { config, .. } => ServerConfig::try_from(config), + Self::HackClub { config, .. } => Ok(ServerConfig { + address: config.local_addresses.clone(), + name: config.tun_name.clone(), + mtu: config.mtu.map(i32::from), + }), + } + } + + pub async fn start( + self, + tun_interface: Arc>>, + ) -> Result { + match self { + Self::WireGuard { identity, config } => { + let tun = TunOptions::new().open()?; + tun_interface.write().await.replace(tun); + + match start_wireguard_runtime(config, tun_interface.clone()).await { + Ok((interface, task)) => { + Ok(ActiveTunnel::WireGuard { identity, interface, task }) + } + Err(err) => { + tun_interface.write().await.take(); + Err(err) + } + } + } + Self::HackClub { identity, config } => { + let mut tun_opts = TunOptions::new(); + if let Some(name) = config.tun_name.as_deref() { + tun_opts = tun_opts.name(name); + } + + let tun = tun_opts.open()?; + tun_interface.write().await.replace(tun); + + match mesh_iroh::spawn_hackclub_tunnel(config, tun_interface.clone()).await { + Ok(handle) => Ok(ActiveTunnel::HackClub { identity, handle }), + Err(err) => { + tun_interface.write().await.take(); + Err(err) + } + } + } + } + } +} + +pub enum ActiveTunnel { + WireGuard { + identity: RuntimeIdentity, + interface: Arc>, + task: JoinHandle>, + }, + HackClub { + identity: RuntimeIdentity, + handle: MeshHandle, + }, +} + +impl ActiveTunnel { + pub fn identity(&self) -> &RuntimeIdentity { + match self { + Self::WireGuard { identity, .. } | Self::HackClub { identity, .. } => identity, + } + } + + pub async fn shutdown(self, tun_interface: &Arc>>) -> Result<()> { + match self { + Self::WireGuard { interface, task, .. } => { + interface.read().await.remove_tun().await; + let task_result = task.await; + tun_interface.write().await.take(); + task_result??; + Ok(()) + } + Self::HackClub { handle, .. } => { + let result = handle.shutdown().await; + tun_interface.write().await.take(); + result + } + } + } +} + +async fn start_wireguard_runtime( + config: Config, + tun_interface: Arc>>, +) -> Result<(Arc>, JoinHandle>)> { + let mut interface: WireGuardInterface = config.try_into()?; + interface.set_tun_ref(tun_interface).await; + let interface = Arc::new(RwLock::new(interface)); + let run_interface = interface.clone(); + let task = tokio::spawn(async move { + let guard = run_interface.read().await; + guard.run().await + }); + Ok((interface, task)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn no_networks_resolves_to_no_tunnel() { + assert!(ResolvedTunnel::from_networks(&[]).unwrap().is_none()); + } +} diff --git a/burrow/src/database.rs b/burrow/src/database.rs index 9a9aac3..c03048c 100644 --- a/burrow/src/database.rs +++ b/burrow/src/database.rs @@ -5,11 +5,9 @@ use rusqlite::{params, Connection}; use crate::{ daemon::rpc::grpc_defs::{ - Network as RPCNetwork, - NetworkDeleteRequest, - NetworkReorderRequest, - NetworkType, + Network as RPCNetwork, NetworkDeleteRequest, NetworkReorderRequest, NetworkType, }, + mesh::iroh::HackClubNetworkConfig, wireguard::config::{Config, Interface, Peer}, }; @@ -124,35 +122,26 @@ pub fn dump_interface(conn: &Connection, config: &Config) -> Result<()> { pub fn get_connection(path: Option<&Path>) -> Result { let p = path.unwrap_or_else(|| std::path::Path::new(DB_PATH)); - if !p.exists() { - let conn = Connection::open(p)?; - initialize_tables(&conn)?; - dump_interface(&conn, &Config::default())?; - return Ok(conn); - } - Ok(Connection::open(p)?) + let conn = Connection::open(p)?; + initialize_tables(&conn)?; + Ok(conn) } pub fn add_network(conn: &Connection, network: &RPCNetwork) -> Result<()> { + validate_network_payload(network)?; let mut stmt = conn.prepare("INSERT INTO network (id, type, payload) VALUES (?, ?, ?)")?; stmt.execute(params![ network.id, network.r#type().as_str_name(), &network.payload ])?; - if network.r#type() == NetworkType::WireGuard { - let payload_str = String::from_utf8(network.payload.clone())?; - let wg_config = Config::from_content_fmt(&payload_str, "ini")?; - dump_interface(conn, &wg_config)?; - } Ok(()) } pub fn list_networks(conn: &Connection) -> Result> { - let mut stmt = conn.prepare("SELECT id, type, payload FROM network ORDER BY idx")?; + let mut stmt = conn.prepare("SELECT id, type, payload FROM network ORDER BY idx, id")?; let networks: Vec = stmt .query_map([], |row| { - println!("row: {:?}", row); let network_id: i32 = row.get(0)?; let network_type: String = row.get(1)?; let network_type = NetworkType::from_str_name(network_type.as_str()) @@ -169,12 +158,19 @@ pub fn list_networks(conn: &Connection) -> Result> { } pub fn reorder_network(conn: &Connection, req: NetworkReorderRequest) -> Result<()> { - let mut stmt = conn.prepare("UPDATE network SET idx = ? WHERE id = ?")?; - let res = stmt.execute(params![req.index, req.id])?; - if res == 0 { + let mut ordered_ids = ordered_network_ids(conn)?; + let Some(current_idx) = ordered_ids.iter().position(|id| *id == req.id) else { return Err(anyhow::anyhow!("No such network exists")); - } - Ok(()) + }; + + let target_idx = usize::try_from(req.index) + .map_err(|_| anyhow::anyhow!("Network index must be non-negative"))?; + + let moved_id = ordered_ids.remove(current_idx); + let target_idx = target_idx.min(ordered_ids.len()); + ordered_ids.insert(target_idx, moved_id); + + renumber_networks(conn, &ordered_ids) } pub fn delete_network(conn: &Connection, req: NetworkDeleteRequest) -> Result<()> { @@ -183,7 +179,8 @@ pub fn delete_network(conn: &Connection, req: NetworkDeleteRequest) -> Result<() if res == 0 { return Err(anyhow::anyhow!("No such network exists")); } - Ok(()) + let ordered_ids = ordered_network_ids(conn)?; + renumber_networks(conn, &ordered_ids) } fn parse_lst(s: &str) -> Vec { @@ -200,9 +197,83 @@ fn to_lst(v: &Vec) -> String { .join(",") } +fn validate_network_payload(network: &RPCNetwork) -> Result<()> { + match network.r#type() { + NetworkType::WireGuard => { + let payload_str = String::from_utf8(network.payload.clone())?; + Config::from_content_fmt(&payload_str, "ini")?; + } + NetworkType::HackClub => { + HackClubNetworkConfig::from_payload(&network.payload)?; + } + } + Ok(()) +} + +fn ordered_network_ids(conn: &Connection) -> Result> { + let mut stmt = conn.prepare("SELECT id FROM network ORDER BY idx, id")?; + let ids = stmt + .query_map([], |row| row.get::<_, i32>(0))? + .collect::>>()?; + Ok(ids) +} + +fn renumber_networks(conn: &Connection, ordered_ids: &[i32]) -> Result<()> { + conn.execute_batch("BEGIN IMMEDIATE")?; + let result = (|| -> Result<()> { + let mut stmt = conn.prepare("UPDATE network SET idx = ? WHERE id = ?")?; + for (idx, id) in ordered_ids.iter().enumerate() { + stmt.execute(params![idx as i32, id])?; + } + Ok(()) + })(); + + match result { + Ok(()) => { + conn.execute_batch("COMMIT")?; + Ok(()) + } + Err(err) => { + let _ = conn.execute_batch("ROLLBACK"); + Err(err) + } + } +} + #[cfg(test)] mod tests { use super::*; + use iroh::PublicKey; + use serde_json::json; + use tempfile::tempdir; + + fn sample_wireguard_payload() -> Vec { + br#"[Interface] +PrivateKey = OEPVdomeLTxTIBvv3TYsJRge0Hp9NMiY0sIrhT8OWG8= +Address = 10.13.13.2/24 +ListenPort = 51820 + +[Peer] +PublicKey = 8GaFjVO6c4luCHG4ONO+1bFG8tO+Zz5/Gy+Geht1USM= +PresharedKey = ha7j4BjD49sIzyF9SNlbueK0AMHghlj6+u0G3bzC698= +AllowedIPs = 0.0.0.0/0, 8.8.8.8/32 +Endpoint = wg.burrow.rs:51820 +"# + .to_vec() + } + + fn sample_hackclub_payload(name: &str, address: &str) -> Vec { + let endpoint_id = PublicKey::from_bytes(&[0; 32]).unwrap().to_string(); + json!({ + "endpoint_id": endpoint_id, + "addresses": ["127.0.0.1:7777"], + "local_addresses": [address], + "mtu": 1380, + "tun_name": name, + }) + .to_string() + .into_bytes() + } #[test] fn test_db() { @@ -213,4 +284,103 @@ mod tests { let loaded = load_interface(&conn, "1").unwrap(); assert_eq!(config, loaded); } + + #[test] + fn add_network_validates_payloads() { + let conn = Connection::open_in_memory().unwrap(); + initialize_tables(&conn).unwrap(); + + add_network( + &conn, + &RPCNetwork { + id: 1, + r#type: NetworkType::WireGuard.into(), + payload: sample_wireguard_payload(), + }, + ) + .unwrap(); + + add_network( + &conn, + &RPCNetwork { + id: 2, + r#type: NetworkType::HackClub.into(), + payload: sample_hackclub_payload("burrow-test-0", "10.42.0.2/32"), + }, + ) + .unwrap(); + + assert!(add_network( + &conn, + &RPCNetwork { + id: 3, + r#type: NetworkType::WireGuard.into(), + payload: b"not-a-config".to_vec(), + }, + ) + .is_err()); + + let ids: Vec = list_networks(&conn) + .unwrap() + .into_iter() + .map(|n| n.id) + .collect(); + assert_eq!(ids, vec![1, 2]); + } + + #[test] + fn reorder_and_delete_networks_keep_priority_stable() { + let conn = Connection::open_in_memory().unwrap(); + initialize_tables(&conn).unwrap(); + + for (id, name, address) in [ + (1, "burrow-test-1", "10.42.0.2/32"), + (2, "burrow-test-2", "10.42.0.3/32"), + (3, "burrow-test-3", "10.42.0.4/32"), + ] { + add_network( + &conn, + &RPCNetwork { + id, + r#type: NetworkType::HackClub.into(), + payload: sample_hackclub_payload(name, address), + }, + ) + .unwrap(); + } + + reorder_network(&conn, NetworkReorderRequest { id: 3, index: 0 }).unwrap(); + let ids: Vec = list_networks(&conn) + .unwrap() + .into_iter() + .map(|n| n.id) + .collect(); + assert_eq!(ids, vec![3, 1, 2]); + + delete_network(&conn, NetworkDeleteRequest { id: 1 }).unwrap(); + let ids: Vec = list_networks(&conn) + .unwrap() + .into_iter() + .map(|n| n.id) + .collect(); + assert_eq!(ids, vec![3, 2]); + } + + #[test] + fn get_connection_does_not_seed_a_default_interface() { + let dir = tempdir().unwrap(); + let db_path = dir.path().join("burrow.sqlite3"); + + let conn = get_connection(Some(db_path.as_path())).unwrap(); + + let interface_count: i64 = conn + .query_row("SELECT COUNT(*) FROM wg_interface", [], |row| row.get(0)) + .unwrap(); + let network_count: i64 = conn + .query_row("SELECT COUNT(*) FROM network", [], |row| row.get(0)) + .unwrap(); + + assert_eq!(interface_count, 0); + assert_eq!(network_count, 0); + } }