Implement sending commands via Unix sockets

This commit is contained in:
dav 2023-07-01 09:44:13 -07:00 committed by David Zhong
parent c8df4b860d
commit f869cbdb53
14 changed files with 555 additions and 177 deletions

View file

@ -7,14 +7,17 @@ edition = "2021"
crate-type = ["lib", "staticlib"]
[dependencies]
tokio = { version = "1.21", features = ["rt", "macros"] }
tun = { version = "0.1", path = "../tun" }
tokio = { version = "1.21", features = ["rt", "sync", "io-util", "macros"] }
tun = { version = "0.1", path = "../tun", features = ["serde"] }
clap = { version = "4.3.2", features = ["derive"] }
env_logger = "0.10"
log = "0.4"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
[target.'cfg(target_os = "linux")'.dependencies]
caps = "0.5.5"
libsystemd = "0.6"
[target.'cfg(target_vendor = "apple")'.dependencies]
nix = { version = "0.26.2" }

View file

@ -0,0 +1,13 @@
use serde::{Deserialize, Serialize};
use tun::TunOptions;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DaemonCommand {
Start(DaemonStartOptions),
Stop,
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct DaemonStartOptions {
pub(super) tun: TunOptions,
}

View file

@ -0,0 +1,40 @@
use super::*;
pub struct DaemonInstance {
rx: mpsc::Receiver<DaemonCommand>,
tun_interface: Option<TunInterface>,
}
impl DaemonInstance {
pub fn new(rx: mpsc::Receiver<DaemonCommand>) -> Self {
Self {
rx,
tun_interface: None,
}
}
pub async fn run(&mut self) -> Result<()> {
while let Some(command) = self.rx.recv().await {
match command {
DaemonCommand::Start(options) => {
if self.tun_interface.is_none() {
self.tun_interface = Some(options.tun.open()?);
eprintln!("Daemon starting tun interface.");
} else {
eprintln!("Got start, but tun interface already up.");
}
}
DaemonCommand::Stop => {
if self.tun_interface.is_some() {
self.tun_interface = None;
eprintln!("Daemon stopping tun interface.");
} else {
eprintln!("Got stop, but tun interface is not up.")
}
}
}
}
Ok(())
}
}

19
burrow/src/daemon/mod.rs Normal file
View file

@ -0,0 +1,19 @@
use super::*;
use tokio::sync::mpsc;
mod command;
mod instance;
mod net;
use instance::DaemonInstance;
use net::listen;
pub use command::{DaemonCommand, DaemonStartOptions};
pub use net::DaemonClient;
pub async fn daemon_main() -> Result<()> {
let (tx, rx) = mpsc::channel(2);
let mut inst = DaemonInstance::new(rx);
tokio::try_join!(inst.run(), listen(tx)).map(|_| ())
}

View file

@ -0,0 +1,29 @@
use super::*;
use serde::{Deserialize, Serialize};
#[cfg(target_family = "unix")]
mod unix;
#[cfg(all(target_family = "unix", not(target_os = "linux")))]
pub use unix::{listen, DaemonClient};
#[cfg(target_os = "linux")]
mod systemd;
#[cfg(target_os = "linux")]
pub use systemd::{listen, DaemonClient};
#[cfg(target_os = "windows")]
mod windows;
#[cfg(target_os = "windows")]
pub use windows::{listen, DaemonClient};
#[derive(Clone, Serialize, Deserialize)]
pub struct DaemonRequest {
pub id: u32,
pub command: DaemonCommand,
}
#[derive(Clone, Serialize, Deserialize)]
pub struct DaemonResponse {
// Error types can't be serialized, so this is the second best option.
result: std::result::Result<(), String>,
}

View file

@ -0,0 +1,16 @@
use super::*;
use std::os::fd::IntoRawFd;
pub async fn listen(cmd_tx: mpsc::Sender<DaemonCommand>) -> Result<()> {
if !libsystemd::daemon::booted() || listen_with_systemd(cmd_tx.clone()).await.is_err() {
unix::listen(cmd_tx).await?;
}
Ok(())
}
async fn listen_with_systemd(cmd_tx: mpsc::Sender<DaemonCommand>) -> Result<()> {
let fds = libsystemd::activation::receive_descriptors(false).unwrap();
super::unix::listen_with_optional_fd(cmd_tx, Some(fds[0].clone().into_raw_fd())).await
}
pub type DaemonClient = unix::DaemonClient;

View file

@ -0,0 +1,102 @@
use super::*;
use std::{
os::fd::{FromRawFd, RawFd},
os::unix::net::UnixListener as StdUnixListener,
path::Path,
};
use tokio::{
io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
net::{UnixListener, UnixStream},
};
const UNIX_SOCKET_PATH: &str = "/run/burrow.sock";
pub async fn listen(cmd_tx: mpsc::Sender<DaemonCommand>) -> Result<()> {
listen_with_optional_fd(cmd_tx, None).await
}
pub(crate) async fn listen_with_optional_fd(
cmd_tx: mpsc::Sender<DaemonCommand>,
raw_fd: Option<RawFd>,
) -> Result<()> {
let path = Path::new(UNIX_SOCKET_PATH);
let listener = if let Some(raw_fd) = raw_fd {
let listener = unsafe { StdUnixListener::from_raw_fd(raw_fd) };
listener.set_nonblocking(true)?;
UnixListener::from_std(listener)
} else {
UnixListener::bind(path)
};
let listener = if let Ok(listener) = listener {
listener
} else {
// Won't help all that much, if we use the async version of fs.
std::fs::remove_file(path)?;
UnixListener::bind(path)?
};
loop {
let (stream, _) = listener.accept().await?;
let cmd_tx = cmd_tx.clone();
// I'm pretty sure we won't need to manually join / shut this down,
// `lines` will return Err during dropping, and this task should exit gracefully.
tokio::task::spawn(async {
let cmd_tx = cmd_tx;
let mut stream = stream;
let (mut read_stream, mut write_stream) = stream.split();
let buf_reader = BufReader::new(&mut read_stream);
let mut lines = buf_reader.lines();
while let Ok(Some(line)) = lines.next_line().await {
let mut res = DaemonResponse { result: Ok(()) };
let command = match serde_json::from_str::<DaemonRequest>(&line) {
Ok(req) => Some(req.command),
Err(e) => {
res.result = Err(format!("{}", e));
None
}
};
let mut res = serde_json::to_string(&res).unwrap();
res.push('\n');
write_stream.write_all(res.as_bytes()).await.unwrap();
// I want this to come at the very end so that we always send a reponse back.
if let Some(command) = command {
cmd_tx.send(command).await.unwrap();
}
}
});
}
}
pub struct DaemonClient {
connection: UnixStream,
}
impl DaemonClient {
pub async fn new() -> Result<Self> {
Self::new_with_path(UNIX_SOCKET_PATH).await
}
pub async fn new_with_path(path: &str) -> Result<Self> {
let path = Path::new(path);
let connection = UnixStream::connect(path).await?;
Ok(Self { connection })
}
pub async fn send_command(&mut self, command: DaemonCommand) -> Result<()> {
let mut command = serde_json::to_string(&DaemonRequest { id: 0, command })?;
command.push('\n');
self.connection.write_all(command.as_bytes()).await?;
let buf_reader = BufReader::new(&mut self.connection);
let mut lines = buf_reader.lines();
// This unwrap *should* never cause issues.
let response = lines.next_line().await?.unwrap();
let res: DaemonResponse = serde_json::from_str(&response)?;
res.result.unwrap();
Ok(())
}
}

View file

@ -0,0 +1,17 @@
use super::*;
pub async fn listen(_: mpsc::Sender<DaemonCommand>) -> Result<()> {
unimplemented!("This platform does not currently support daemon mode.")
}
pub struct DaemonClient;
impl DaemonClient {
pub async fn new() -> Result<Self> {
unimplemented!("This platform does not currently support daemon mode.")
}
pub async fn send_command(&mut self, _: DaemonCommand) -> Result<()> {
unimplemented!("This platform does not currently support daemon mode.")
}
}

View file

@ -8,6 +8,10 @@ use tokio::io::Result;
use burrow::retrieve;
use tun::TunInterface;
mod daemon;
use daemon::{DaemonClient, DaemonCommand, DaemonStartOptions};
#[derive(Parser)]
#[command(name = "Burrow")]
#[command(author = "Hack Club <team@hackclub.com>")]
@ -30,6 +34,10 @@ enum Commands {
Start(StartArgs),
/// Retrieve the file descriptor of the tun interface
Retrieve(RetrieveArgs),
/// Stop Burrow daemon
Stop,
/// Start Burrow daemon
Daemon(DaemonArgs),
}
#[derive(Args)]
@ -38,14 +46,15 @@ struct StartArgs {}
#[derive(Args)]
struct RetrieveArgs {}
#[derive(Args)]
struct DaemonArgs {}
#[cfg(any(target_os = "linux", target_vendor = "apple"))]
async fn try_start() -> Result<()> {
burrow::ensureroot::ensure_root();
let iface = TunInterface::new()?;
println!("{:?}", iface.name());
let iface2 = retrieve();
println!("{}", iface2);
Ok(())
let mut client = DaemonClient::new().await?;
client
.send_command(DaemonCommand::Start(DaemonStartOptions::default()))
.await
}
#[cfg(any(target_os = "linux", target_vendor = "apple"))]
@ -56,6 +65,13 @@ async fn try_retrieve() -> Result<()> {
Ok(())
}
#[cfg(any(target_os = "linux", target_vendor = "apple"))]
async fn try_stop() -> Result<()> {
let mut client = DaemonClient::new().await?;
client.send_command(DaemonCommand::Stop).await?;
Ok(())
}
#[cfg(not(any(target_os = "linux", target_vendor = "apple")))]
async fn try_start() -> Result<()> {
Ok(())
@ -66,8 +82,13 @@ async fn try_retrieve() -> Result<()> {
Ok(())
}
#[cfg(not(any(target_os = "linux", target_vendor = "apple")))]
async fn try_stop() -> Result<()> {
Ok(())
}
#[tokio::main(flavor = "current_thread")]
async fn main() {
async fn main() -> Result<()> {
println!("Platform: {}", std::env::consts::OS);
let cli = Cli::parse();
@ -80,5 +101,11 @@ async fn main() {
try_retrieve().await.unwrap();
println!("FINISHED");
}
Commands::Stop => {
try_stop().await.unwrap();
}
Commands::Daemon(_) => daemon::daemon_main().await?,
}
Ok(())
}