Compare commits

..

No commits in common. "4408e9aca8bc05cf508d5374e00020348acb326b" and "4dd31d5f1e0838fdd8a55d0d445e120bf8d1943e" have entirely different histories.

38 changed files with 509 additions and 597 deletions

View file

@ -113,7 +113,7 @@ final class BurrowIpc {
return data return data
} }
func request<U: Decodable>(_ request: any Request, type: U.Type) async throws -> U { func request<U: Decodable>(_ request: Request, type: U.Type) async throws -> U {
do { do {
var data: Data = try JSONEncoder().encode(request) var data: Data = try JSONEncoder().encode(request)
data.append(contentsOf: [10]) data.append(contentsOf: [10])

View file

@ -7,40 +7,16 @@ enum BurrowError: Error {
case resultIsNone case resultIsNone
} }
protocol Request: Codable where T: Codable{ protocol Request: Codable {
associatedtype T
var id: UInt { get set } var id: UInt { get set }
var command: T { get set } var command: String { get set }
} }
struct BurrowSingleCommand: Request { struct BurrowRequest: Request {
var id: UInt var id: UInt
var command: String var command: String
} }
struct BurrowRequest<T>: Request where T: Codable{
var id: UInt
var command: T
}
struct BurrowStartRequest: Codable {
struct TunOptions: Codable{
let name: String?
let no_pi: Bool
let tun_excl: Bool
let seek_utun: Int?
let address: String?
}
struct StartOptions: Codable{
let tun: TunOptions
}
let Start: StartOptions
}
func start_req_fd(id: UInt, fd: Int) -> BurrowRequest<BurrowStartRequest> {
return BurrowRequest(id: id, command: BurrowStartRequest(Start: BurrowStartRequest.StartOptions(tun: BurrowStartRequest.TunOptions(name: nil, no_pi: false, tun_excl: false, seek_utun: fd, address: nil))))
}
struct Response<T>: Decodable where T: Decodable { struct Response<T>: Decodable where T: Decodable {
var id: UInt var id: UInt
var result: T var result: T

View file

@ -6,7 +6,7 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
let logger = Logger(subsystem: "com.hackclub.burrow", category: "frontend") let logger = Logger(subsystem: "com.hackclub.burrow", category: "frontend")
var client: BurrowIpc? var client: BurrowIpc?
var osInitialized = false var osInitialized = false
override func startTunnel(options: [String : NSObject]? = nil) async throws { override func startTunnel(options: [String: NSObject]?, completionHandler: @escaping (Error?) -> Void) {
logger.log("Starting tunnel") logger.log("Starting tunnel")
if !osInitialized { if !osInitialized {
libburrow.initialize_oslog() libburrow.initialize_oslog()
@ -15,35 +15,28 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
libburrow.start_srv() libburrow.start_srv()
client = BurrowIpc(logger: logger) client = BurrowIpc(logger: logger)
logger.info("Started server") logger.info("Started server")
do { Task {
let command = BurrowSingleCommand(id: 0, command: "ServerConfig") do {
guard let data = try await client?.request(command, type: Response<BurrowResult<ServerConfigData>>.self) let command = BurrowRequest(id: 0, command: "ServerConfig")
else { guard let data = try await client?.request(command, type: Response<BurrowResult<ServerConfigData>>.self)
throw BurrowError.cantParseResult else {
throw BurrowError.cantParseResult
}
let encoded = try JSONEncoder().encode(data.result)
self.logger.log("Received final data: \(String(decoding: encoded, as: UTF8.self))")
guard let serverconfig = data.result.Ok else {
throw BurrowError.resultIsError
}
guard let tunNs = self.generateTunSettings(from: serverconfig) else {
throw BurrowError.addrDoesntExist
}
try await self.setTunnelNetworkSettings(tunNs)
self.logger.info("Set remote tunnel address to \(tunNs.tunnelRemoteAddress)")
completionHandler(nil)
} catch {
self.logger.error("An error occurred: \(error)")
completionHandler(error)
} }
let encoded = try JSONEncoder().encode(data.result)
self.logger.log("Received final data: \(String(decoding: encoded, as: UTF8.self))")
guard let serverconfig = data.result.Ok else {
throw BurrowError.resultIsError
}
guard let tunNs = self.generateTunSettings(from: serverconfig) else {
throw BurrowError.addrDoesntExist
}
try await self.setTunnelNetworkSettings(tunNs)
self.logger.info("Set remote tunnel address to \(tunNs.tunnelRemoteAddress)")
// let tunFd = self.packetFlow.value(forKeyPath: "socket.fileDescriptor") as! Int;
// self.logger.info("Found File Descriptor: \(tunFd)")
let start_command = start_req_fd(id: 1, fd: 0)
guard let data = try await client?.request(start_command, type: Response<BurrowResult<String>>.self)
else {
throw BurrowError.cantParseResult
}
let encoded_startres = try JSONEncoder().encode(data.result)
self.logger.log("Received start server response: \(String(decoding: encoded_startres, as: UTF8.self))")
} catch {
self.logger.error("An error occurred: \(error)")
throw error
} }
} }
private func generateTunSettings(from: ServerConfigData) -> NETunnelNetworkSettings? { private func generateTunSettings(from: ServerConfigData) -> NETunnelNetworkSettings? {
@ -57,16 +50,17 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
logger.log("Initialized ipv4 settings: \(nst.ipv4Settings)") logger.log("Initialized ipv4 settings: \(nst.ipv4Settings)")
return nst return nst
} }
override func stopTunnel(with reason: NEProviderStopReason) async { override func stopTunnel(with reason: NEProviderStopReason, completionHandler: @escaping () -> Void) {
completionHandler()
} }
override func handleAppMessage(_ messageData: Data) async -> Data? { override func handleAppMessage(_ messageData: Data, completionHandler: ((Data?) -> Void)?) {
messageData if let handler = completionHandler {
handler(messageData)
}
} }
override func sleep() async { override func sleep(completionHandler: @escaping () -> Void) {
completionHandler()
} }
override func wake() { override func wake() {
} }
} }

View file

@ -1,13 +1,15 @@
use tracing::debug; use tracing::{debug, Subscriber};
use tracing::instrument::WithSubscriber;
use tracing_oslog::OsLogger; use tracing_oslog::OsLogger;
use tracing_subscriber::FmtSubscriber;
use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::layer::SubscriberExt;
pub use crate::daemon::start_srv; pub use crate::daemon::start_srv;
#[no_mangle] #[no_mangle]
pub extern "C" fn initialize_oslog() { pub extern "C" fn initialize_oslog() {
let collector = let collector = tracing_subscriber::registry()
tracing_subscriber::registry().with(OsLogger::new("com.hackclub.burrow", "backend")); .with(OsLogger::new("com.hackclub.burrow", "backend"));
tracing::subscriber::set_global_default(collector).unwrap(); tracing::subscriber::set_global_default(collector).unwrap();
debug!("Initialized oslog tracing in libburrow rust FFI"); debug!("Initialized oslog tracing in libburrow rust FFI");
} }

View file

@ -12,22 +12,21 @@ pub enum DaemonCommand {
#[derive(Debug, Clone, Default, Serialize, Deserialize, JsonSchema)] #[derive(Debug, Clone, Default, Serialize, Deserialize, JsonSchema)]
pub struct DaemonStartOptions { pub struct DaemonStartOptions {
pub tun: TunOptions, pub(super) tun: TunOptions,
} }
#[test] #[test]
fn test_daemoncommand_serialization() { fn test_daemoncommand_serialization() {
insta::assert_snapshot!(serde_json::to_string(&DaemonCommand::Start(
DaemonStartOptions::default()
))
.unwrap());
insta::assert_snapshot!( insta::assert_snapshot!(
serde_json::to_string(&DaemonCommand::Start(DaemonStartOptions { serde_json::to_string(&DaemonCommand::Start(DaemonStartOptions::default())).unwrap()
tun: TunOptions { ..TunOptions::default() }
}))
.unwrap()
); );
insta::assert_snapshot!(serde_json::to_string(&DaemonCommand::ServerInfo).unwrap()); insta::assert_snapshot!(
insta::assert_snapshot!(serde_json::to_string(&DaemonCommand::Stop).unwrap()); serde_json::to_string(&DaemonCommand::ServerInfo).unwrap()
insta::assert_snapshot!(serde_json::to_string(&DaemonCommand::ServerConfig).unwrap()) );
} insta::assert_snapshot!(
serde_json::to_string(&DaemonCommand::Stop).unwrap()
);
insta::assert_snapshot!(
serde_json::to_string(&DaemonCommand::ServerConfig).unwrap()
)
}

View file

@ -1,105 +1,50 @@
use std::sync::Arc;
use anyhow::Result;
use tokio::{sync::RwLock, task::JoinHandle};
use tracing::{debug, info, warn}; use tracing::{debug, info, warn};
use tun::tokio::TunInterface; use DaemonResponse;
use tun::TunInterface;
use crate::{ use crate::daemon::response::{DaemonResponseData, ServerConfig, ServerInfo};
daemon::{ use super::*;
command::DaemonCommand,
response::{DaemonResponse, DaemonResponseData, ServerConfig, ServerInfo},
},
wireguard::Interface,
};
enum RunState {
Running(JoinHandle<Result<()>>),
Idle,
}
pub struct DaemonInstance { pub struct DaemonInstance {
rx: async_channel::Receiver<DaemonCommand>, rx: async_channel::Receiver<DaemonCommand>,
sx: async_channel::Sender<DaemonResponse>, sx: async_channel::Sender<DaemonResponse>,
tun_interface: Option<Arc<RwLock<TunInterface>>>, tun_interface: Option<TunInterface>,
wg_interface: Arc<RwLock<Interface>>,
wg_state: RunState,
} }
impl DaemonInstance { impl DaemonInstance {
pub fn new( pub fn new(rx: async_channel::Receiver<DaemonCommand>, sx: async_channel::Sender<DaemonResponse>) -> Self {
rx: async_channel::Receiver<DaemonCommand>,
sx: async_channel::Sender<DaemonResponse>,
wg_interface: Arc<RwLock<Interface>>,
) -> Self {
Self { Self {
rx, rx,
sx, sx,
wg_interface,
tun_interface: None, tun_interface: None,
wg_state: RunState::Idle,
} }
} }
pub fn set_tun_interface(&mut self, tun_interface: Arc<RwLock<TunInterface>>) {
self.tun_interface = Some(tun_interface);
}
async fn proc_command(&mut self, command: DaemonCommand) -> Result<DaemonResponseData> { async fn proc_command(&mut self, command: DaemonCommand) -> Result<DaemonResponseData> {
info!("Daemon got command: {:?}", command); info!("Daemon got command: {:?}", command);
match command { match command {
DaemonCommand::Start(st) => { DaemonCommand::Start(st) => {
match self.wg_state { if self.tun_interface.is_none() {
RunState::Running(_) => { debug!("Daemon attempting start tun interface.");
warn!("Got start, but tun interface already up."); self.tun_interface = Some(st.tun.open()?);
} info!("Daemon started tun interface");
RunState::Idle => { } else {
let raw = tun::TunInterface::retrieve().unwrap(); warn!("Got start, but tun interface already up.");
debug!("TunInterface retrieved: {:?}", raw.name()?);
let retrieved = TunInterface::new(raw)?;
let tun_if = Arc::new(RwLock::new(retrieved));
// let tun_if = Arc::new(RwLock::new(TunInterface::new(st.tun.open()?)?));
debug!("Setting tun_interface");
self.tun_interface = Some(tun_if.clone());
debug!("tun_interface set: {:?}", self.tun_interface);
debug!("Setting tun on wg_interface");
self.wg_interface.write().await.set_tun(tun_if);
debug!("tun set on wg_interface");
debug!("Cloning wg_interface");
let tmp_wg = self.wg_interface.clone();
debug!("wg_interface cloned");
debug!("Spawning run task");
let run_task = tokio::spawn(async move {
debug!("Running wg_interface");
let twlock = tmp_wg.read().await;
debug!("wg_interface read lock acquired");
twlock.run().await
});
debug!("Run task spawned: {:?}", run_task);
debug!("Setting wg_state to Running");
self.wg_state = RunState::Running(run_task);
debug!("wg_state set to Running");
info!("Daemon started tun interface");
}
} }
Ok(DaemonResponseData::None) Ok(DaemonResponseData::None)
} }
DaemonCommand::ServerInfo => match &self.tun_interface { DaemonCommand::ServerInfo => {
None => Ok(DaemonResponseData::None), match &self.tun_interface {
Some(ti) => { None => {Ok(DaemonResponseData::None)}
info!("{:?}", ti); Some(ti) => {
Ok(DaemonResponseData::ServerInfo(ServerInfo::try_from( info!("{:?}", ti);
ti.read().await.inner.get_ref(), Ok(
)?)) DaemonResponseData::ServerInfo(
ServerInfo::try_from(ti)?
)
)
}
} }
}, }
DaemonCommand::Stop => { DaemonCommand::Stop => {
if self.tun_interface.is_some() { if self.tun_interface.is_some() {
self.tun_interface = None; self.tun_interface = None;
@ -116,7 +61,6 @@ impl DaemonInstance {
} }
pub async fn run(&mut self) -> Result<()> { pub async fn run(&mut self) -> Result<()> {
tracing::info!("BEGIN");
while let Ok(command) = self.rx.recv().await { while let Ok(command) = self.rx.recv().await {
let response = self.proc_command(command).await; let response = self.proc_command(command).await;
info!("Daemon response: {:?}", response); info!("Daemon response: {:?}", response);

View file

@ -1,7 +1,5 @@
use std::{ use std::net::{Ipv4Addr, SocketAddr, ToSocketAddrs};
net::{IpAddr, Ipv4Addr, SocketAddr},
sync::Arc,
};
mod command; mod command;
mod instance; mod instance;
@ -12,18 +10,15 @@ use anyhow::{Error, Result};
use base64::{engine::general_purpose, Engine as _}; use base64::{engine::general_purpose, Engine as _};
pub use command::{DaemonCommand, DaemonStartOptions}; pub use command::{DaemonCommand, DaemonStartOptions};
use fehler::throws; use fehler::throws;
use instance::DaemonInstance;
use ip_network::{IpNetwork, Ipv4Network}; use ip_network::{IpNetwork, Ipv4Network};
use instance::DaemonInstance;
use crate::wireguard::{StaticSecret, Peer, Interface, PublicKey};
pub use net::DaemonClient;
#[cfg(target_vendor = "apple")] #[cfg(target_vendor = "apple")]
pub use net::start_srv; pub use net::start_srv;
pub use net::DaemonClient;
pub use response::{DaemonResponse, DaemonResponseData, ServerInfo};
use tokio::sync::RwLock;
use crate::{ pub use response::{DaemonResponseData, DaemonResponse, ServerInfo};
daemon::net::listen,
wireguard::{Interface, Peer, PublicKey, StaticSecret},
};
#[throws] #[throws]
fn parse_key(string: &str) -> [u8; 32] { fn parse_key(string: &str) -> [u8; 32] {
@ -48,13 +43,18 @@ fn parse_public_key(string: &str) -> PublicKey {
pub async fn daemon_main() -> Result<()> { pub async fn daemon_main() -> Result<()> {
let (commands_tx, commands_rx) = async_channel::unbounded(); let (commands_tx, commands_rx) = async_channel::unbounded();
let (response_tx, response_rx) = async_channel::unbounded(); let (response_tx, response_rx) = async_channel::unbounded();
let mut inst = DaemonInstance::new(commands_rx, response_tx);
let mut _tun = tun::TunInterface::new()?;
_tun.set_ipv4_addr(Ipv4Addr::from([192, 168, 1, 10]))?;
_tun.set_timeout(Some(std::time::Duration::from_secs(1)))?;
let tun = tun::tokio::TunInterface::new(_tun)?;
let private_key = parse_secret_key("GNqIAOCRxjl/cicZyvkvpTklgQuUmGUIEkH7IXF/sEE=")?; let private_key = parse_secret_key("GNqIAOCRxjl/cicZyvkvpTklgQuUmGUIEkH7IXF/sEE=")?;
let public_key = parse_public_key("uy75leriJay0+oHLhRMpV+A5xAQ0hCJ+q7Ww81AOvT4=")?; let public_key = parse_public_key("uy75leriJay0+oHLhRMpV+A5xAQ0hCJ+q7Ww81AOvT4=")?;
let preshared_key = Some(parse_key("s7lx/mg+reVEMnGnqeyYOQkzD86n2+gYnx1M9ygi08k=")?); let preshared_key = Some(parse_key("s7lx/mg+reVEMnGnqeyYOQkzD86n2+gYnx1M9ygi08k=")?);
let endpoint = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(198, 18, 6, 180)), 51820); // DNS lookup under macos fails, somehow let endpoint = "wg.burrow.rs:51820".to_socket_addrs()?.next().unwrap();
let iface = Interface::new(tun, vec![Peer {
let iface = Interface::new(vec![Peer {
endpoint, endpoint,
private_key, private_key,
public_key, public_key,
@ -62,25 +62,6 @@ pub async fn daemon_main() -> Result<()> {
allowed_ips: vec![IpNetwork::V4(Ipv4Network::DEFAULT_ROUTE)], allowed_ips: vec![IpNetwork::V4(Ipv4Network::DEFAULT_ROUTE)],
}])?; }])?;
let mut inst: DaemonInstance = iface.run().await;
DaemonInstance::new(commands_rx, response_tx, Arc::new(RwLock::new(iface)));
tracing::info!("Starting daemon jobs...");
let inst_job = tokio::spawn(async move {
let res = inst.run().await;
if let Err(e) = res {
tracing::error!("Error when running instance: {}", e);
}
});
let listen_job = tokio::spawn(async move {
let res = listen(commands_tx, response_rx).await;
if let Err(e) = res {
tracing::error!("Error when listening: {}", e);
}
});
tokio::try_join!(inst_job, listen_job).map(|_| ());
Ok(()) Ok(())
} }

View file

@ -1,13 +1,10 @@
use std::thread; use std::thread;
use tokio::runtime::Runtime; use tokio::runtime::Runtime;
use tracing::{error, info}; use tracing::error;
use crate::daemon::{daemon_main, DaemonClient}; use crate::daemon::{daemon_main, DaemonClient};
#[no_mangle] #[no_mangle]
pub extern "C" fn start_srv() { pub extern "C" fn start_srv(){
info!("Rust: Starting server");
let _handle = thread::spawn(move || { let _handle = thread::spawn(move || {
let rt = Runtime::new().unwrap(); let rt = Runtime::new().unwrap();
rt.block_on(async { rt.block_on(async {
@ -19,12 +16,9 @@ pub extern "C" fn start_srv() {
let rt = Runtime::new().unwrap(); let rt = Runtime::new().unwrap();
rt.block_on(async { rt.block_on(async {
loop { loop {
match DaemonClient::new().await { if let Ok(_) = DaemonClient::new().await{
Ok(_) => break, break
Err(_e) => {
// error!("Error when connecting to daemon: {}", e)
}
} }
} }
}); });
} }

View file

@ -29,3 +29,4 @@ pub struct DaemonRequest {
pub id: u32, pub id: u32,
pub command: DaemonCommand, pub command: DaemonCommand,
} }

View file

@ -1,23 +1,13 @@
pub async fn listen( pub async fn listen(cmd_tx: async_channel::Sender<DaemonCommand>, rsp_rx: async_channel::Receiver<DaemonResponse>) -> Result<()> {
cmd_tx: async_channel::Sender<DaemonCommand>, if !libsystemd::daemon::booted() || listen_with_systemd(cmd_tx.clone(), rsp_rx.clone()).await.is_err() {
rsp_rx: async_channel::Receiver<DaemonResponse>,
) -> Result<()> {
if !libsystemd::daemon::booted()
|| listen_with_systemd(cmd_tx.clone(), rsp_rx.clone())
.await
.is_err()
{
unix::listen(cmd_tx, rsp_rx).await?; unix::listen(cmd_tx, rsp_rx).await?;
} }
Ok(()) Ok(())
} }
async fn listen_with_systemd( async fn listen_with_systemd(cmd_tx: async_channel::Sender<DaemonCommand>, rsp_rx: async_channel::Receiver<DaemonResponse>) -> Result<()> {
cmd_tx: async_channel::Sender<DaemonCommand>,
rsp_rx: async_channel::Receiver<DaemonResponse>,
) -> Result<()> {
let fds = libsystemd::activation::receive_descriptors(false)?; let fds = libsystemd::activation::receive_descriptors(false)?;
super::unix::listen_with_optional_fd(cmd_tx, rsp_rx, Some(fds[0].clone().into_raw_fd())).await super::unix::listen_with_optional_fd(cmd_tx, rsp_rx,Some(fds[0].clone().into_raw_fd())).await
} }
pub type DaemonClient = unix::DaemonClient; pub type DaemonClient = unix::DaemonClient;

View file

@ -1,20 +1,22 @@
use super::*;
use std::{ use std::{
io, ascii, io, os::{
os::{
fd::{FromRawFd, RawFd}, fd::{FromRawFd, RawFd},
unix::net::UnixListener as StdUnixListener, unix::net::UnixListener as StdUnixListener,
}, },
path::{Path, PathBuf}, path::Path};
}; use std::hash::Hash;
use std::path::PathBuf;
use anyhow::anyhow;
use log::log;
use tracing::info;
use anyhow::{anyhow, Result}; use anyhow::Result;
use tokio::{ use tokio::{
io::{AsyncBufReadExt, AsyncWriteExt, BufReader}, io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
net::{UnixListener, UnixStream}, net::{UnixListener, UnixStream},
}; };
use tracing::{debug, info}; use tracing::debug;
use super::*;
use crate::daemon::{DaemonCommand, DaemonResponse, DaemonResponseData}; use crate::daemon::{DaemonCommand, DaemonResponse, DaemonResponseData};
#[cfg(not(target_vendor = "apple"))] #[cfg(not(target_vendor = "apple"))]
@ -24,33 +26,28 @@ const UNIX_SOCKET_PATH: &str = "/run/burrow.sock";
const UNIX_SOCKET_PATH: &str = "burrow.sock"; const UNIX_SOCKET_PATH: &str = "burrow.sock";
#[cfg(target_os = "macos")] #[cfg(target_os = "macos")]
fn fetch_socket_path() -> Option<PathBuf> { fn fetch_socket_path() -> Option<PathBuf>{
let tries = vec![ let tries = vec![
"burrow.sock".to_string(), "burrow.sock".to_string(),
format!( format!("{}/Library/Containers/com.hackclub.burrow.network/Data/burrow.sock",
"{}/Library/Containers/com.hackclub.burrow.network/Data/burrow.sock", std::env::var("HOME").unwrap_or_default())
std::env::var("HOME").unwrap_or_default() .to_string(),
)
.to_string(),
]; ];
for path in tries { for path in tries{
let path = PathBuf::from(path); let path = PathBuf::from(path);
if path.exists() { if path.exists(){
return Some(path) return Some(path);
} }
} }
None None
} }
#[cfg(not(target_os = "macos"))] #[cfg(not(target_os = "macos"))]
fn fetch_socket_path() -> Option<PathBuf> { fn fetch_socket_path() -> Option<PathBuf>{
Some(Path::new(UNIX_SOCKET_PATH).to_path_buf()) Some(Path::new(UNIX_SOCKET_PATH).to_path_buf())
} }
pub async fn listen( pub async fn listen(cmd_tx: async_channel::Sender<DaemonCommand>, rsp_rx: async_channel::Receiver<DaemonResponse>) -> Result<()> {
cmd_tx: async_channel::Sender<DaemonCommand>,
rsp_rx: async_channel::Receiver<DaemonResponse>,
) -> Result<()> {
listen_with_optional_fd(cmd_tx, rsp_rx, None).await listen_with_optional_fd(cmd_tx, rsp_rx, None).await
} }
@ -72,12 +69,14 @@ pub(crate) async fn listen_with_optional_fd(
listener listener
} else { } else {
// Won't help all that much, if we use the async version of fs. // Won't help all that much, if we use the async version of fs.
if let Some(par) = path.parent() { if let Some(par) = path.parent(){
std::fs::create_dir_all(par)?; std::fs::create_dir_all(
par
)?;
} }
match std::fs::remove_file(path) { match std::fs::remove_file(path){
Err(e) if e.kind() == io::ErrorKind::NotFound => Ok(()), Err(e) if e.kind()==io::ErrorKind::NotFound => {Ok(())}
stuff => stuff, stuff => stuff
}?; }?;
info!("Relative path: {}", path.to_string_lossy()); info!("Relative path: {}", path.to_string_lossy());
UnixListener::bind(path)? UnixListener::bind(path)?
@ -99,18 +98,18 @@ pub(crate) async fn listen_with_optional_fd(
while let Ok(Some(line)) = lines.next_line().await { while let Ok(Some(line)) = lines.next_line().await {
info!("Got line: {}", line); info!("Got line: {}", line);
debug!("Line raw data: {:?}", line.as_bytes()); debug!("Line raw data: {:?}", line.as_bytes());
let mut res: DaemonResponse = DaemonResponseData::None.into(); let mut res : DaemonResponse = DaemonResponseData::None.into();
let req = match serde_json::from_str::<DaemonRequest>(&line) { let req = match serde_json::from_str::<DaemonRequest>(&line) {
Ok(req) => Some(req), Ok(req) => Some(req),
Err(e) => { Err(e) => {
res.result = Err(e.to_string()); res.result = Err(e.to_string());
tracing::error!("Failed to parse request: {}", e);
None None
} }
}; };
let mut res = serde_json::to_string(&res).unwrap(); let mut res = serde_json::to_string(&res).unwrap();
res.push('\n'); res.push('\n');
if let Some(req) = req { if let Some(req) = req {
cmd_tx.send(req.command).await.unwrap(); cmd_tx.send(req.command).await.unwrap();
let res = rsp_rxc.recv().await.unwrap().with_id(req.id); let res = rsp_rxc.recv().await.unwrap().with_id(req.id);
@ -118,8 +117,6 @@ pub(crate) async fn listen_with_optional_fd(
retres.push('\n'); retres.push('\n');
info!("Sending response: {}", retres); info!("Sending response: {}", retres);
write_stream.write_all(retres.as_bytes()).await.unwrap(); write_stream.write_all(retres.as_bytes()).await.unwrap();
} else {
write_stream.write_all(res.as_bytes()).await.unwrap();
} }
} }
}); });
@ -132,7 +129,8 @@ pub struct DaemonClient {
impl DaemonClient { impl DaemonClient {
pub async fn new() -> Result<Self> { pub async fn new() -> Result<Self> {
let path = fetch_socket_path().ok_or(anyhow!("Failed to find socket path"))?; let path = fetch_socket_path()
.ok_or(anyhow!("Failed to find socket path"))?;
// debug!("found path: {:?}", path); // debug!("found path: {:?}", path);
let connection = UnixStream::connect(path).await?; let connection = UnixStream::connect(path).await?;
debug!("connected to socket"); debug!("connected to socket");

View file

@ -1,9 +1,6 @@
use super::*; use super::*;
pub async fn listen( pub async fn listen(_cmd_tx: async_channel::Sender<DaemonCommand>, _rsp_rx: async_channel::Receiver<DaemonResponse>) -> Result<()> {
_cmd_tx: async_channel::Sender<DaemonCommand>,
_rsp_rx: async_channel::Receiver<DaemonResponse>,
) -> Result<()> {
unimplemented!("This platform does not currently support daemon mode.") unimplemented!("This platform does not currently support daemon mode.")
} }

View file

@ -1,3 +1,4 @@
use anyhow::anyhow;
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tun::TunInterface; use tun::TunInterface;
@ -6,27 +7,30 @@ use tun::TunInterface;
pub struct DaemonResponse { pub struct DaemonResponse {
// Error types can't be serialized, so this is the second best option. // Error types can't be serialized, so this is the second best option.
pub result: Result<DaemonResponseData, String>, pub result: Result<DaemonResponseData, String>,
pub id: u32, pub id: u32
} }
impl DaemonResponse { impl DaemonResponse{
pub fn new(result: Result<DaemonResponseData, impl ToString>) -> Self { pub fn new(result: Result<DaemonResponseData, impl ToString>) -> Self{
Self { Self{
result: result.map_err(|e| e.to_string()), result: result.map_err(|e| e.to_string()),
id: 0, id: 0
} }
} }
} }
impl Into<DaemonResponse> for DaemonResponseData { impl Into<DaemonResponse> for DaemonResponseData{
fn into(self) -> DaemonResponse { fn into(self) -> DaemonResponse{
DaemonResponse::new(Ok::<DaemonResponseData, String>(self)) DaemonResponse::new(Ok::<DaemonResponseData, String>(self))
} }
} }
impl DaemonResponse { impl DaemonResponse{
pub fn with_id(self, id: u32) -> Self { pub fn with_id(self, id: u32) -> Self{
Self { id, ..self } Self {
id,
..self
}
} }
} }
@ -34,22 +38,24 @@ impl DaemonResponse {
pub struct ServerInfo { pub struct ServerInfo {
pub name: Option<String>, pub name: Option<String>,
pub ip: Option<String>, pub ip: Option<String>,
pub mtu: Option<i32>, pub mtu: Option<i32>
} }
impl TryFrom<&TunInterface> for ServerInfo { impl TryFrom<&TunInterface> for ServerInfo{
type Error = anyhow::Error; type Error = anyhow::Error;
#[cfg(any(target_os = "linux", target_vendor = "apple"))] #[cfg(any(target_os="linux",target_vendor="apple"))]
fn try_from(server: &TunInterface) -> anyhow::Result<Self> { fn try_from(server: &TunInterface) -> anyhow::Result<Self> {
Ok(ServerInfo { Ok(
name: server.name().ok(), ServerInfo{
ip: server.ipv4_addr().ok().map(|ip| ip.to_string()), name: server.name().ok(),
mtu: server.mtu().ok(), ip: server.ipv4_addr().ok().map(|ip| ip.to_string()),
}) mtu: server.mtu().ok()
}
)
} }
#[cfg(not(any(target_os = "linux", target_vendor = "apple")))] #[cfg(not(any(target_os="linux",target_vendor="apple")))]
fn try_from(server: &TunInterface) -> anyhow::Result<Self> { fn try_from(server: &TunInterface) -> anyhow::Result<Self> {
Err(anyhow!("Not implemented in this platform")) Err(anyhow!("Not implemented in this platform"))
} }
@ -59,55 +65,45 @@ impl TryFrom<&TunInterface> for ServerInfo {
pub struct ServerConfig { pub struct ServerConfig {
pub address: Option<String>, pub address: Option<String>,
pub name: Option<String>, pub name: Option<String>,
pub mtu: Option<i32>, pub mtu: Option<i32>
} }
impl Default for ServerConfig { impl Default for ServerConfig {
fn default() -> Self { fn default() -> Self {
Self { Self{
address: Some("10.13.13.2".to_string()), // Dummy remote address address: Some("10.0.0.1".to_string()), // Dummy remote address
name: None, name: None,
mtu: None, mtu: None
} }
} }
} }
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
pub enum DaemonResponseData { pub enum DaemonResponseData{
ServerInfo(ServerInfo), ServerInfo(ServerInfo),
ServerConfig(ServerConfig), ServerConfig(ServerConfig),
None, None
} }
#[test] #[test]
fn test_response_serialization() -> anyhow::Result<()> { fn test_response_serialization() -> anyhow::Result<()>{
insta::assert_snapshot!(serde_json::to_string(&DaemonResponse::new(Ok::< insta::assert_snapshot!(
DaemonResponseData, serde_json::to_string(&DaemonResponse::new(Ok::<DaemonResponseData, String>(DaemonResponseData::None)))?
String, );
>( insta::assert_snapshot!(
DaemonResponseData::None serde_json::to_string(&DaemonResponse::new(Ok::<DaemonResponseData, String>(DaemonResponseData::ServerInfo(ServerInfo{
)))?);
insta::assert_snapshot!(serde_json::to_string(&DaemonResponse::new(Ok::<
DaemonResponseData,
String,
>(
DaemonResponseData::ServerInfo(ServerInfo {
name: Some("burrow".to_string()), name: Some("burrow".to_string()),
ip: None, ip: None,
mtu: Some(1500) mtu: Some(1500)
}) }))))?
)))?); );
insta::assert_snapshot!(serde_json::to_string(&DaemonResponse::new(Err::< insta::assert_snapshot!(
DaemonResponseData, serde_json::to_string(&DaemonResponse::new(Err::<DaemonResponseData, String>("error".to_string())))?
String, );
>( insta::assert_snapshot!(
"error".to_string() serde_json::to_string(&DaemonResponse::new(Ok::<DaemonResponseData, String>(DaemonResponseData::ServerConfig(
)))?); ServerConfig::default()
insta::assert_snapshot!(serde_json::to_string(&DaemonResponse::new(Ok::< ))))?
DaemonResponseData, );
String,
>(
DaemonResponseData::ServerConfig(ServerConfig::default())
)))?);
Ok(()) Ok(())
} }

View file

@ -1,5 +1,5 @@
--- ---
source: burrow/src/daemon/command.rs source: burrow/src/daemon/command.rs
expression: "serde_json::to_string(&DaemonCommand::Start(DaemonStartOptions {\n tun: TunOptions { seek_utun: true, ..TunOptions::default() },\n })).unwrap()" expression: "serde_json::to_string(&DaemonCommand::ServerInfo).unwrap()"
--- ---
{"Start":{"tun":{"name":null,"no_pi":false,"tun_excl":false,"seek_utun":true,"address":null}}} "ServerInfo"

View file

@ -1,5 +1,5 @@
--- ---
source: burrow/src/daemon/command.rs source: burrow/src/daemon/command.rs
expression: "serde_json::to_string(&DaemonCommand::ServerInfo).unwrap()" expression: "serde_json::to_string(&DaemonCommand::Stop).unwrap()"
--- ---
"ServerInfo" "Stop"

View file

@ -1,5 +1,5 @@
--- ---
source: burrow/src/daemon/command.rs source: burrow/src/daemon/command.rs
expression: "serde_json::to_string(&DaemonCommand::Stop).unwrap()" expression: "serde_json::to_string(&DaemonCommand::ServerConfig).unwrap()"
--- ---
"Stop" "ServerConfig"

View file

@ -1,5 +0,0 @@
---
source: burrow/src/daemon/command.rs
expression: "serde_json::to_string(&DaemonCommand::ServerConfig).unwrap()"
---
"ServerConfig"

View file

@ -2,4 +2,4 @@
source: burrow/src/daemon/command.rs source: burrow/src/daemon/command.rs
expression: "serde_json::to_string(&DaemonCommand::Start(DaemonStartOptions::default())).unwrap()" expression: "serde_json::to_string(&DaemonCommand::Start(DaemonStartOptions::default())).unwrap()"
--- ---
{"Start":{"tun":{"name":null,"no_pi":false,"tun_excl":false,"seek_utun":false,"address":null}}} {"Start":{"tun":{"name":null,"no_pi":null,"tun_excl":null}}}

View file

@ -2,4 +2,4 @@
source: burrow/src/daemon/response.rs source: burrow/src/daemon/response.rs
expression: "serde_json::to_string(&DaemonResponse::new(Ok::<DaemonResponseData,\n String>(DaemonResponseData::ServerConfig(ServerConfig::default()))))?" expression: "serde_json::to_string(&DaemonResponse::new(Ok::<DaemonResponseData,\n String>(DaemonResponseData::ServerConfig(ServerConfig::default()))))?"
--- ---
{"result":{"Ok":{"ServerConfig":{"address":"10.13.13.2","name":null,"mtu":null}}},"id":0} {"result":{"Ok":{"ServerConfig":{"address":"10.0.0.1","name":null,"mtu":null}}},"id":0}

40
burrow/src/ensureroot.rs Normal file
View file

@ -0,0 +1,40 @@
use tracing::instrument;
// Check capabilities on Linux
#[cfg(target_os = "linux")]
#[instrument]
pub fn ensure_root() {
use caps::{has_cap, CapSet, Capability};
let cap_net_admin = Capability::CAP_NET_ADMIN;
if let Ok(has_cap) = has_cap(None, CapSet::Effective, cap_net_admin) {
if !has_cap {
eprintln!(
"This action needs the CAP_NET_ADMIN permission. Did you mean to run it as root?"
);
std::process::exit(77);
}
} else {
eprintln!("Failed to check capabilities. Please file a bug report!");
std::process::exit(71);
}
}
// Check for root user on macOS
#[cfg(target_vendor = "apple")]
#[instrument]
pub fn ensure_root() {
use nix::unistd::Uid;
let current_uid = Uid::current();
if !current_uid.is_root() {
eprintln!("This action must be run as root!");
std::process::exit(77);
}
}
#[cfg(target_family = "windows")]
#[instrument]
pub fn ensure_root() {
todo!()
}

View file

@ -1,16 +1,44 @@
pub mod ensureroot;
pub mod wireguard; pub mod wireguard;
mod daemon; use anyhow::Result;
pub use daemon::{
DaemonCommand, #[cfg(any(target_os = "linux", target_vendor = "apple"))]
DaemonResponse, use std::{
DaemonResponseData, mem,
DaemonStartOptions, os::fd::{AsRawFd, FromRawFd},
ServerInfo,
}; };
use tun::TunInterface;
// TODO Separate start and retrieve functions
mod daemon;
pub use daemon::{DaemonCommand, DaemonResponseData, DaemonStartOptions, DaemonResponse, ServerInfo};
#[cfg(target_vendor = "apple")] #[cfg(target_vendor = "apple")]
mod apple; mod apple;
#[cfg(target_vendor = "apple")] #[cfg(target_vendor = "apple")]
pub use apple::*; pub use apple::*;
#[cfg(any(target_os = "linux", target_vendor = "apple"))]
#[no_mangle]
pub extern "C" fn retrieve() -> i32 {
let iface2 = (1..100)
.filter_map(|i| {
let iface = unsafe { TunInterface::from_raw_fd(i) };
match iface.name() {
Ok(_name) => Some(iface),
Err(_) => {
mem::forget(iface);
None
}
}
})
.next();
match iface2 {
Some(iface) => iface.as_raw_fd(),
None => -1,
}
}

View file

@ -1,19 +1,23 @@
use std::mem;
#[cfg(any(target_os = "linux", target_vendor = "apple"))]
use std::os::fd::FromRawFd;
use anyhow::{Context, Result}; use anyhow::{Context, Result};
#[cfg(any(target_os = "linux", target_vendor = "apple"))] #[cfg(any(target_os = "linux", target_vendor = "apple"))]
use burrow::retrieve;
use clap::{Args, Parser, Subcommand}; use clap::{Args, Parser, Subcommand};
use tracing::instrument; use tracing::{instrument, Level};
use tracing_log::LogTracer; use tracing_log::LogTracer;
use tracing_oslog::OsLogger; use tracing_oslog::OsLogger;
use tracing_subscriber::{prelude::*, EnvFilter, FmtSubscriber}; use tracing_subscriber::{prelude::*, FmtSubscriber, EnvFilter};
#[cfg(any(target_os = "linux", target_vendor = "apple"))] #[cfg(any(target_os = "linux", target_vendor = "apple"))]
use tun::TunInterface; use tun::TunInterface;
mod daemon; mod daemon;
mod wireguard; mod wireguard;
use daemon::{DaemonClient, DaemonCommand, DaemonStartOptions}; use daemon::{DaemonClient, DaemonCommand, DaemonStartOptions};
use tun::TunOptions;
use crate::daemon::DaemonResponseData; use crate::daemon::DaemonResponseData;
#[derive(Parser)] #[derive(Parser)]
@ -61,9 +65,7 @@ struct DaemonArgs {}
async fn try_start() -> Result<()> { async fn try_start() -> Result<()> {
let mut client = DaemonClient::new().await?; let mut client = DaemonClient::new().await?;
client client
.send_command(DaemonCommand::Start(DaemonStartOptions { .send_command(DaemonCommand::Start(DaemonStartOptions::default()))
tun: TunOptions::new().address("10.13.13.2"),
}))
.await .await
.map(|_| ()) .map(|_| ())
} }
@ -85,8 +87,9 @@ async fn try_retrieve() -> Result<()> {
} }
} }
let iface2 = TunInterface::retrieve().ok_or(anyhow::anyhow!("No interface found"))?; burrow::ensureroot::ensure_root();
tracing::info!("{:?}", iface2); let iface2 = retrieve();
tracing::info!("{}", iface2);
Ok(()) Ok(())
} }
@ -101,10 +104,9 @@ async fn initialize_tracing() -> Result<()> {
FmtSubscriber::builder() FmtSubscriber::builder()
.with_line_number(true) .with_line_number(true)
.with_env_filter(EnvFilter::from_default_env()) .with_env_filter(EnvFilter::from_default_env())
.finish(), .finish()
); );
tracing::subscriber::set_global_default(logger) tracing::subscriber::set_global_default(logger).context("Failed to set the global tracing subscriber")?;
.context("Failed to set the global tracing subscriber")?;
} }
} }
@ -119,7 +121,7 @@ async fn try_stop() -> Result<()> {
} }
#[cfg(any(target_os = "linux", target_vendor = "apple"))] #[cfg(any(target_os = "linux", target_vendor = "apple"))]
async fn try_serverinfo() -> Result<()> { async fn try_serverinfo() -> Result<()>{
let mut client = DaemonClient::new().await?; let mut client = DaemonClient::new().await?;
let res = client.send_command(DaemonCommand::ServerInfo).await?; let res = client.send_command(DaemonCommand::ServerInfo).await?;
match res.result { match res.result {
@ -129,9 +131,7 @@ async fn try_serverinfo() -> Result<()> {
Ok(DaemonResponseData::None) => { Ok(DaemonResponseData::None) => {
println!("Server not started.") println!("Server not started.")
} }
Ok(res) => { Ok(res) => {println!("Unexpected Response: {:?}", res)}
println!("Unexpected Response: {:?}", res)
}
Err(e) => { Err(e) => {
println!("Error when retrieving from server: {}", e) println!("Error when retrieving from server: {}", e)
} }
@ -140,7 +140,7 @@ async fn try_serverinfo() -> Result<()> {
} }
#[cfg(any(target_os = "linux", target_vendor = "apple"))] #[cfg(any(target_os = "linux", target_vendor = "apple"))]
async fn try_serverconfig() -> Result<()> { async fn try_serverconfig() -> Result<()>{
let mut client = DaemonClient::new().await?; let mut client = DaemonClient::new().await?;
let res = client.send_command(DaemonCommand::ServerConfig).await?; let res = client.send_command(DaemonCommand::ServerConfig).await?;
match res.result { match res.result {
@ -150,9 +150,7 @@ async fn try_serverconfig() -> Result<()> {
Ok(DaemonResponseData::None) => { Ok(DaemonResponseData::None) => {
println!("Server not started.") println!("Server not started.")
} }
Ok(res) => { Ok(res) => {println!("Unexpected Response: {:?}", res)}
println!("Unexpected Response: {:?}", res)
}
Err(e) => { Err(e) => {
println!("Error when retrieving from server: {}", e) println!("Error when retrieving from server: {}", e)
} }
@ -203,8 +201,12 @@ async fn main() -> Result<()> {
try_stop().await?; try_stop().await?;
} }
Commands::Daemon(_) => daemon::daemon_main().await?, Commands::Daemon(_) => daemon::daemon_main().await?,
Commands::ServerInfo => try_serverinfo().await?, Commands::ServerInfo => {
Commands::ServerConfig => try_serverconfig().await?, try_serverinfo().await?
}
Commands::ServerConfig => {
try_serverconfig().await?
}
} }
Ok(()) Ok(())

View file

@ -1,15 +1,23 @@
use std::{net::IpAddr, sync::Arc, time::Duration}; use std::{net::IpAddr, rc::Rc};
use std::sync::Arc;
use std::time::Duration;
use anyhow::Error; use anyhow::Error;
use async_trait::async_trait; use async_trait::async_trait;
use fehler::throws; use fehler::throws;
use futures::{future::join_all, FutureExt};
use ip_network_table::IpNetworkTable; use ip_network_table::IpNetworkTable;
use tokio::{sync::RwLock, task::JoinHandle, time::timeout}; use log::log;
use tracing::{debug, error}; use tokio::{
join,
sync::{Mutex, RwLock},
task::{self, JoinHandle},
};
use tun::tokio::TunInterface; use tun::tokio::TunInterface;
use futures::future::join_all;
use futures::FutureExt;
use tokio::time::timeout;
use super::{noise::Tunnel, Peer, PeerPcb}; use super::{noise::Tunnel, pcb, Peer, PeerPcb};
#[async_trait] #[async_trait]
pub trait PacketInterface { pub trait PacketInterface {
@ -29,7 +37,7 @@ impl PacketInterface for tun::tokio::TunInterface {
} }
struct IndexedPcbs { struct IndexedPcbs {
pcbs: Vec<Arc<PeerPcb>>, pcbs: Vec<Arc<RwLock<PeerPcb>>>,
allowed_ips: IpNetworkTable<usize>, allowed_ips: IpNetworkTable<usize>,
} }
@ -46,7 +54,7 @@ impl IndexedPcbs {
for allowed_ip in pcb.allowed_ips.iter() { for allowed_ip in pcb.allowed_ips.iter() {
self.allowed_ips.insert(allowed_ip.clone(), idx); self.allowed_ips.insert(allowed_ip.clone(), idx);
} }
self.pcbs.insert(idx, Arc::new(pcb)); self.pcbs.insert(idx, Arc::new(RwLock::new(pcb)));
} }
pub fn find(&self, addr: IpAddr) -> Option<usize> { pub fn find(&self, addr: IpAddr) -> Option<usize> {
@ -55,7 +63,7 @@ impl IndexedPcbs {
} }
pub async fn connect(&self, idx: usize, handle: JoinHandle<()>) { pub async fn connect(&self, idx: usize, handle: JoinHandle<()>) {
self.pcbs[idx].handle.write().await.replace(handle); self.pcbs[idx].write().await.handle = Some(handle);
} }
} }
@ -69,124 +77,115 @@ impl FromIterator<PeerPcb> for IndexedPcbs {
} }
pub struct Interface { pub struct Interface {
tun: Option<Arc<RwLock<TunInterface>>>, tun: Arc<RwLock<TunInterface>>,
pcbs: Arc<IndexedPcbs>, pcbs: Arc<IndexedPcbs>,
} }
impl Interface { impl Interface {
#[throws] #[throws]
pub fn new<I: IntoIterator<Item = Peer>>(peers: I) -> Self { pub fn new<I: IntoIterator<Item = Peer>>(tun: TunInterface, peers: I) -> Self {
let pcbs: IndexedPcbs = peers let mut pcbs: IndexedPcbs = peers
.into_iter() .into_iter()
.map(|peer| PeerPcb::new(peer)) .map(|peer| PeerPcb::new(peer))
.collect::<Result<_, _>>()?; .collect::<Result<_, _>>()?;
let tun = Arc::new(RwLock::new(tun));
let pcbs = Arc::new(pcbs); let pcbs = Arc::new(pcbs);
Self { pcbs, tun: None } Self { tun, pcbs }
} }
pub fn set_tun(&mut self, tun: Arc<RwLock<TunInterface>>) { pub async fn run(self) {
self.tun = Some(tun);
}
pub async fn run(&self) -> anyhow::Result<()> {
debug!("RUN: starting interface");
let pcbs = self.pcbs.clone(); let pcbs = self.pcbs.clone();
let tun = self let tun = self.tun.clone();
.tun
.clone()
.ok_or(anyhow::anyhow!("tun interface does not exist"))?;
log::info!("starting interface"); log::info!("starting interface");
let outgoing = async move { let outgoing = async move {
loop { loop {
// tracing::debug!("starting loop..."); log::debug!("starting loop...");
let mut buf = [0u8; 3000]; let mut buf = [0u8; 3000];
let src = { let src = {
let src = match timeout( log::debug!("awaiting read...");
Duration::from_millis(10), let src = match timeout(Duration::from_secs(2), tun.write().await.recv(&mut buf[..])).await {
tun.read().await.recv(&mut buf[..]),
)
.await
{
Ok(Ok(len)) => &buf[..len], Ok(Ok(len)) => &buf[..len],
Ok(Err(e)) => { Ok(Err(e)) => {continue}
error!("failed to read from interface: {}", e); Err(_would_block) => {
continue continue
} }
Err(_would_block) => continue,
}; };
debug!("read {} bytes from interface", src.len()); log::debug!("read {} bytes from interface", src.len());
debug!("bytes: {:?}", src); log::debug!("bytes: {:?}", src);
src src
}; };
let dst_addr = match Tunnel::dst_address(src) { let dst_addr = match Tunnel::dst_address(src) {
Some(addr) => addr, Some(addr) => addr,
None => { None => {
tracing::debug!("no destination found"); log::debug!("no destination found");
continue continue
} },
}; };
tracing::debug!("dst_addr: {}", dst_addr); log::debug!("dst_addr: {}", dst_addr);
let Some(idx) = pcbs.find(dst_addr) else { let Some(idx) = pcbs.find(dst_addr) else {
continue continue
}; };
tracing::debug!("found peer:{}", idx); log::debug!("found peer:{}", idx);
match pcbs.pcbs[idx].send(src).await { match pcbs.pcbs[idx].read().await.send(src).await {
Ok(..) => { Ok(..) => {
let addr = pcbs.pcbs[idx].endpoint; log::debug!("sent packet to peer {}", dst_addr);
tracing::debug!("sent packet to peer {}", addr);
} }
Err(e) => { Err(e) => {
log::error!("failed to send packet {}", e); log::error!("failed to send packet {}", e);
continue continue
} },
}; };
// let mut buf = [0u8; 3000];
// match pcbs.pcbs[idx].read().await.recv(&mut buf).await {
// Ok(len) => log::debug!("received {} bytes from peer {}", len, dst_addr),
// Err(e) => {
// log::error!("failed to receive packet {}", e);
// continue
// },
// }
} }
}; };
let mut tsks = vec![]; let mut tsks = vec![];
let tun = self let tun = self.tun.clone();
.tun
.clone()
.ok_or(anyhow::anyhow!("tun interface does not exist"))?;
let outgoing = tokio::task::spawn(outgoing); let outgoing = tokio::task::spawn(outgoing);
tsks.push(outgoing); tsks.push(outgoing);
debug!("preparing to spawn read tasks");
{ {
let pcbs = &self.pcbs; let pcbs = self.pcbs;
for i in 0..pcbs.pcbs.len() { for i in 0..pcbs.pcbs.len(){
debug!("spawning read task for peer {}", i); let mut pcb = pcbs.pcbs[i].clone();
let pcb = pcbs.pcbs[i].clone();
let tun = tun.clone(); let tun = tun.clone();
let tsk = async move { let tsk = async move {
if let Err(e) = pcb.open_if_closed().await { {
log::error!("failed to open pcb: {}", e); let r1 = pcb.write().await.open_if_closed().await;
return if let Err(e) = r1 {
log::error!("failed to open pcb: {}", e);
return
}
} }
let r2 = pcb.run(tun).await; let r2 = pcb.read().await.run().await;
if let Err(e) = r2 { if let Err(e) = r2 {
log::error!("failed to run pcb: {}", e); log::error!("failed to run pcb: {}", e);
return return
} else { } else {
tracing::debug!("pcb ran successfully"); log::debug!("pcb ran successfully");
} }
}; };
debug!("task made..");
tsks.push(tokio::spawn(tsk)); tsks.push(tokio::spawn(tsk));
} }
debug!("spawned read tasks"); log::debug!("spawned read tasks");
} }
debug!("preparing to join.."); log::debug!("preparing to join..");
join_all(tsks).await; join_all(tsks).await;
debug!("joined!");
Ok(())
} }
} }

View file

@ -136,6 +136,10 @@ fn aead_chacha20_open(
) -> Result<(), WireGuardError> { ) -> Result<(), WireGuardError> {
let mut nonce: [u8; 12] = [0; 12]; let mut nonce: [u8; 12] = [0; 12];
nonce[4..].copy_from_slice(&counter.to_le_bytes()); nonce[4..].copy_from_slice(&counter.to_le_bytes());
log::debug!("TAG A");
log::debug!("{:?};{:?};{:?};{:?};{}", key, data, aad, nonce, counter);
aead_chacha20_open_inner(buffer, key, nonce, data, aad) aead_chacha20_open_inner(buffer, key, nonce, data, aad)
.map_err(|_| WireGuardError::InvalidAeadTag)?; .map_err(|_| WireGuardError::InvalidAeadTag)?;
Ok(()) Ok(())
@ -679,6 +683,7 @@ impl Handshake {
aad: &mac1[0..16], aad: &mac1[0..16],
msg: packet.encrypted_cookie, msg: packet.encrypted_cookie,
}; };
log::debug!("TAG B");
let plaintext = XChaCha20Poly1305::new_from_slice(&key) let plaintext = XChaCha20Poly1305::new_from_slice(&key)
.unwrap() .unwrap()
.decrypt(packet.nonce.into(), payload) .decrypt(packet.nonce.into(), payload)

View file

@ -146,7 +146,7 @@ impl Tunnel {
// Checks the type, as well as the reserved zero fields // Checks the type, as well as the reserved zero fields
let packet_type = u32::from_le_bytes(src[0..4].try_into().unwrap()); let packet_type = u32::from_le_bytes(src[0..4].try_into().unwrap());
tracing::debug!("packet_type: {}", packet_type); log::debug!("packet_type: {}", packet_type);
Ok(match (packet_type, src.len()) { Ok(match (packet_type, src.len()) {
(HANDSHAKE_INIT, HANDSHAKE_INIT_SZ) => Packet::HandshakeInit(HandshakeInit { (HANDSHAKE_INIT, HANDSHAKE_INIT_SZ) => Packet::HandshakeInit(HandshakeInit {

View file

@ -6,6 +6,7 @@ use std::{
use aead::{generic_array::GenericArray, AeadInPlace, KeyInit}; use aead::{generic_array::GenericArray, AeadInPlace, KeyInit};
use chacha20poly1305::{Key, XChaCha20Poly1305}; use chacha20poly1305::{Key, XChaCha20Poly1305};
use log::log;
use parking_lot::Mutex; use parking_lot::Mutex;
use rand_core::{OsRng, RngCore}; use rand_core::{OsRng, RngCore};
use ring::constant_time::verify_slices_are_equal; use ring::constant_time::verify_slices_are_equal;
@ -173,14 +174,14 @@ impl RateLimiter {
dst: &'b mut [u8], dst: &'b mut [u8],
) -> Result<Packet<'a>, TunnResult<'b>> { ) -> Result<Packet<'a>, TunnResult<'b>> {
let packet = Tunnel::parse_incoming_packet(src)?; let packet = Tunnel::parse_incoming_packet(src)?;
tracing::debug!("packet: {:?}", packet); log::debug!("packet: {:?}", packet);
// Verify and rate limit handshake messages only // Verify and rate limit handshake messages only
if let Packet::HandshakeInit(HandshakeInit { sender_idx, .. }) if let Packet::HandshakeInit(HandshakeInit { sender_idx, .. })
| Packet::HandshakeResponse(HandshakeResponse { sender_idx, .. }) = packet | Packet::HandshakeResponse(HandshakeResponse { sender_idx, .. }) = packet
{ {
tracing::debug!("sender_idx: {}", sender_idx); log::debug!("sender_idx: {}", sender_idx);
tracing::debug!("response: {:?}", packet); log::debug!("response: {:?}", packet);
let (msg, macs) = src.split_at(src.len() - 32); let (msg, macs) = src.split_at(src.len() - 32);
let (mac1, mac2) = macs.split_at(16); let (mac1, mac2) = macs.split_at(16);

View file

@ -253,7 +253,7 @@ impl Session {
// check the counter without running expensive decryption // check the counter without running expensive decryption
self.receiving_counter_quick_check(packet.counter)?; self.receiving_counter_quick_check(packet.counter)?;
tracing::debug!("TAG C"); log::debug!("TAG C");
let ret = { let ret = {
let mut nonce = [0u8; 12]; let mut nonce = [0u8; 12];
nonce[4..12].copy_from_slice(&packet.counter.to_le_bytes()); nonce[4..12].copy_from_slice(&packet.counter.to_le_bytes());

View file

@ -1,17 +1,21 @@
use std::{ use std::io;
cell::{Cell, RefCell}, use std::net::SocketAddr;
net::SocketAddr, use std::rc::Rc;
sync::Arc, use std::sync::Arc;
}; use std::time::Duration;
use anyhow::{anyhow, Error}; use anyhow::{anyhow, Error};
use fehler::throws; use fehler::throws;
use ip_network::IpNetwork; use ip_network::IpNetwork;
use log::log;
use rand::random; use rand::random;
use tokio::{net::UdpSocket, sync::RwLock, task::JoinHandle}; use tokio::{net::UdpSocket, task::JoinHandle};
use tun::tokio::TunInterface; use tokio::sync::{Mutex, RwLock};
use tokio::time::timeout;
use uuid::uuid;
use super::{ use super::{
iface::PacketInterface,
noise::{TunnResult, Tunnel}, noise::{TunnResult, Tunnel},
Peer, Peer,
}; };
@ -20,101 +24,107 @@ use super::{
pub struct PeerPcb { pub struct PeerPcb {
pub endpoint: SocketAddr, pub endpoint: SocketAddr,
pub allowed_ips: Vec<IpNetwork>, pub allowed_ips: Vec<IpNetwork>,
pub handle: RwLock<Option<JoinHandle<()>>>, pub handle: Option<JoinHandle<()>>,
socket: RwLock<Option<UdpSocket>>, socket: Option<UdpSocket>,
tunnel: RwLock<Tunnel>, tunnel: RwLock<Tunnel>,
} }
impl PeerPcb { impl PeerPcb {
#[throws] #[throws]
pub fn new(peer: Peer) -> Self { pub fn new(peer: Peer) -> Self {
let tunnel = RwLock::new( let tunnel = RwLock::new(Tunnel::new(peer.private_key, peer.public_key, peer.preshared_key, None, 1, None)
Tunnel::new( .map_err(|s| anyhow::anyhow!("{}", s))?);
peer.private_key,
peer.public_key,
peer.preshared_key,
None,
1,
None,
)
.map_err(|s| anyhow::anyhow!("{}", s))?,
);
Self { Self {
endpoint: peer.endpoint, endpoint: peer.endpoint,
allowed_ips: peer.allowed_ips, allowed_ips: peer.allowed_ips,
handle: RwLock::new(None), handle: None,
socket: RwLock::new(None), socket: None,
tunnel, tunnel,
} }
} }
pub async fn open_if_closed(&self) -> Result<(), Error> { pub async fn open_if_closed(&mut self) -> Result<(), Error> {
if self.socket.read().await.is_none() { if self.socket.is_none() {
let socket = UdpSocket::bind("0.0.0.0:0").await?; let socket = UdpSocket::bind("0.0.0.0:0").await?;
socket.connect(self.endpoint).await?; socket.connect(self.endpoint).await?;
self.socket.write().await.replace(socket); self.socket = Some(socket);
} }
Ok(()) Ok(())
} }
pub async fn run(&self, tun_interface: Arc<RwLock<TunInterface>>) -> Result<(), Error> { pub async fn run(&self) -> Result<(), Error> {
tracing::debug!("starting read loop for pcb... for {:?}", &self); let mut buf = [0u8; 3000];
let rid: i32 = random(); log::debug!("starting read loop for pcb...");
let mut buf: [u8; 3000] = [0u8; 3000];
tracing::debug!("start read loop {}", rid);
loop { loop {
tracing::debug!("{}: waiting for packet", rid); tracing::debug!("waiting for packet");
let guard = self.socket.read().await; let len = self.recv(&mut buf).await?;
let Some(socket) = guard.as_ref() else { tracing::debug!("received {} bytes", len);
}
}
pub async fn recv(&self, buf: &mut [u8]) -> Result<usize, Error> {
log::debug!("starting read loop for pcb... for {:?}", &self);
let rid: i32 = random();
log::debug!("start read loop {}", rid);
loop{
log::debug!("{}: waiting for packet", rid);
let Some(socket) = &self.socket else {
continue continue
}; };
let mut res_buf = [0; 1500]; let mut res_buf = [0;1500];
// tracing::debug!("{} : waiting for readability on {:?}", rid, socket); log::debug!("{} : waiting for readability on {:?}", rid, socket);
let len = match socket.recv(&mut res_buf).await { match timeout(Duration::from_secs(2), socket.readable()).await {
Ok(l) => l,
Err(e) => { Err(e) => {
log::error!("{}: error reading from socket: {:?}", rid, e); log::debug!("{}: timeout waiting for readability on {:?}", rid, e);
continue continue
} }
Ok(Err(e)) => {
log::debug!("{}: error waiting for readability on {:?}", rid, e);
continue
}
Ok(Ok(_)) => {}
};
log::debug!("{}: readable!", rid);
let Ok(len) = socket.try_recv(&mut res_buf) else {
continue
}; };
let mut res_dat = &res_buf[..len]; let mut res_dat = &res_buf[..len];
tracing::debug!("{}: Decapsulating {} bytes", rid, len); tracing::debug!("{}: Decapsulating {} bytes", rid, len);
tracing::debug!("{:?}", &res_dat); tracing::debug!("{:?}", &res_dat);
loop { loop {
match self match self.tunnel.write().await.decapsulate(None, res_dat, &mut buf[..]) {
.tunnel TunnResult::Done => {
.write() break;
.await }
.decapsulate(None, res_dat, &mut buf[..])
{
TunnResult::Done => break,
TunnResult::Err(e) => { TunnResult::Err(e) => {
tracing::error!(message = "Decapsulate error", error = ?e); tracing::error!(message = "Decapsulate error", error = ?e);
break break;
} }
TunnResult::WriteToNetwork(packet) => { TunnResult::WriteToNetwork(packet) => {
tracing::debug!("WriteToNetwork: {:?}", packet); tracing::debug!("WriteToNetwork: {:?}", packet);
self.open_if_closed().await?;
socket.send(packet).await?; socket.send(packet).await?;
tracing::debug!("WriteToNetwork done"); tracing::debug!("WriteToNetwork done");
res_dat = &[]; res_dat = &[];
continue continue;
} }
TunnResult::WriteToTunnelV4(packet, addr) => { TunnResult::WriteToTunnelV4(packet, addr) => {
tracing::debug!("WriteToTunnelV4: {:?}, {:?}", packet, addr); tracing::debug!("WriteToTunnelV4: {:?}, {:?}", packet, addr);
tun_interface.read().await.send(packet).await?; continue;
break
}
TunnResult::WriteToTunnelV6(packet, addr) => {
tracing::debug!("WriteToTunnelV6: {:?}, {:?}", packet, addr);
tun_interface.read().await.send(packet).await?;
break
} }
e => panic!("Unexpected result from decapsulate: {:?}", e),
} }
} }
return Ok(len)
} }
} }
pub async fn socket(&mut self) -> Result<&UdpSocket, Error> {
self.open_if_closed().await?;
Ok(self.socket.as_ref().expect("socket was just opened"))
}
pub async fn send(&self, src: &[u8]) -> Result<(), Error> { pub async fn send(&self, src: &[u8]) -> Result<(), Error> {
let mut dst_buf = [0u8; 3000]; let mut dst_buf = [0u8; 3000];
match self.tunnel.write().await.encapsulate(src, &mut dst_buf[..]) { match self.tunnel.write().await.encapsulate(src, &mut dst_buf[..]) {
@ -123,12 +133,7 @@ impl PeerPcb {
tracing::error!(message = "Encapsulate error", error = ?e) tracing::error!(message = "Encapsulate error", error = ?e)
} }
TunnResult::WriteToNetwork(packet) => { TunnResult::WriteToNetwork(packet) => {
self.open_if_closed().await?; let socket = &self.socket.as_ref().ok_or(anyhow!("socket not open"))?;
let handle = self.socket.read().await;
let Some(socket) = handle.as_ref() else {
tracing::error!("No socket for peer");
return Ok(())
};
tracing::debug!("Our Encapsulated packet: {:?}", packet); tracing::debug!("Our Encapsulated packet: {:?}", packet);
socket.send(packet).await?; socket.send(packet).await?;
} }

View file

@ -1,5 +1,7 @@
use std::{fmt, net::SocketAddr}; use std::{fmt, net::SocketAddr};
use anyhow::Error;
use fehler::throws;
use ip_network::IpNetwork; use ip_network::IpNetwork;
use x25519_dalek::{PublicKey, StaticSecret}; use x25519_dalek::{PublicKey, StaticSecret};
@ -8,7 +10,7 @@ pub struct Peer {
pub private_key: StaticSecret, pub private_key: StaticSecret,
pub public_key: PublicKey, pub public_key: PublicKey,
pub allowed_ips: Vec<IpNetwork>, pub allowed_ips: Vec<IpNetwork>,
pub preshared_key: Option<[u8; 32]>, pub preshared_key: Option<[u8; 32]>
} }
impl fmt::Debug for Peer { impl fmt::Debug for Peer {

View file

@ -26,7 +26,7 @@ async fn generate(out_dir: &std::path::Path) -> anyhow::Result<()> {
println!("cargo:rerun-if-changed={}", binary_path.to_str().unwrap()); println!("cargo:rerun-if-changed={}", binary_path.to_str().unwrap());
if let (Ok(..), Ok(..)) = (File::open(&bindings_path), File::open(&binary_path)) { if let (Ok(..), Ok(..)) = (File::open(&bindings_path), File::open(&binary_path)) {
return Ok(()); return Ok(())
}; };
let archive = download(out_dir) let archive = download(out_dir)

View file

@ -5,48 +5,28 @@ use fehler::throws;
use super::TunInterface; use super::TunInterface;
#[derive(Debug, Clone, Default)] #[derive(Debug, Clone, Default)]
#[cfg_attr( #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema))]
feature = "serde",
derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema)
)]
pub struct TunOptions { pub struct TunOptions {
/// (Windows + Linux) Name the tun interface. /// (Windows + Linux) Name the tun interface.
pub name: Option<String>, pub(crate) name: Option<String>,
/// (Linux) Don't include packet information. /// (Linux) Don't include packet information.
pub no_pi: bool, pub(crate) no_pi: Option<()>,
/// (Linux) Avoid opening an existing persistant device. /// (Linux) Avoid opening an existing persistant device.
pub tun_excl: bool, pub(crate) tun_excl: Option<()>,
/// (Linux) The IP address of the tun interface.
pub address: Option<String>,
} }
impl TunOptions { impl TunOptions {
pub fn new() -> Self { pub fn new() -> Self { Self::default() }
Self::default()
}
pub fn name(mut self, name: &str) -> Self { pub fn name(mut self, name: &str) -> Self {
self.name = Some(name.to_owned()); self.name = Some(name.to_owned());
self self
} }
pub fn no_pi(mut self, enable: bool) -> Self { pub fn no_pi(mut self, enable: bool) { self.no_pi = enable.then_some(()); }
self.no_pi = enable;
self
}
pub fn tun_excl(mut self, enable: bool) -> Self { pub fn tun_excl(mut self, enable: bool) { self.tun_excl = enable.then_some(()); }
self.tun_excl = enable;
self
}
pub fn address(mut self, address: impl ToString) -> Self {
self.address = Some(address.to_string());
self
}
#[throws] #[throws]
pub fn open(self) -> TunInterface { pub fn open(self) -> TunInterface { TunInterface::new_with_options(self)? }
TunInterface::new_with_options(self)?
}
} }

View file

@ -5,14 +5,15 @@ use tracing::instrument;
#[derive(Debug)] #[derive(Debug)]
pub struct TunInterface { pub struct TunInterface {
pub inner: AsyncFd<crate::TunInterface>, inner: AsyncFd<crate::TunInterface>,
} }
impl TunInterface { impl TunInterface {
#[instrument] #[instrument]
pub fn new(mut tun: crate::TunInterface) -> io::Result<Self> { pub fn new(tun: crate::TunInterface) -> io::Result<Self> {
tun.set_nonblocking(true)?; Ok(Self {
Ok(Self { inner: AsyncFd::new(tun)? }) inner: AsyncFd::new(tun)?,
})
} }
#[instrument] #[instrument]
@ -26,17 +27,38 @@ impl TunInterface {
} }
} }
#[instrument] // #[instrument]
pub async fn recv(&self, buf: &mut [u8]) -> io::Result<usize> { pub async fn recv(&mut self, buf: &mut [u8]) -> io::Result<usize> {
loop { loop {
let mut guard = self.inner.readable().await?; log::debug!("TunInterface receiving...");
match guard.try_io(|inner| inner.get_ref().recv(buf)) { let mut guard = self.inner.readable_mut().await?;
Ok(result) => return result, log::debug!("Got! readable_mut");
match guard.try_io(|inner| {
// log::debug!("Got! {:#?}", inner);
let raw_ref = (*inner).get_mut();
// log::debug!("Got mut ref! {:#?}", raw_ref);
let recved = raw_ref.recv(buf);
// log::debug!("Got recved! {:#?}", recved);
recved
}) {
Ok(result) => {
log::debug!("HORRAY");
return result
},
Err(_would_block) => { Err(_would_block) => {
tracing::debug!("WouldBlock"); log::debug!("WouldBlock");
continue continue
} },
} }
} }
} }
#[instrument]
pub async fn try_recv(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let mut guard = self.inner.readable_mut().await?;
match guard.try_io(|inner| (*inner).get_mut().recv(buf)) {
Ok(result) => Ok(result.unwrap_or_default()),
Err(_would_block) => Err(io::Error::new(io::ErrorKind::WouldBlock, "WouldBlock")),
}
}
} }

View file

@ -1,8 +1,8 @@
use std::{ use std::{
io::{Error, IoSlice}, io::{Error, IoSlice},
mem::{self, ManuallyDrop}, mem,
net::{Ipv4Addr, SocketAddrV4}, net::{Ipv4Addr, SocketAddrV4},
os::fd::{AsRawFd, FromRawFd, RawFd}, os::fd::{AsRawFd, RawFd},
}; };
use byteorder::{ByteOrder, NetworkEndian}; use byteorder::{ByteOrder, NetworkEndian};
@ -11,14 +11,13 @@ use libc::{c_char, iovec, writev, AF_INET, AF_INET6};
use socket2::{Domain, SockAddr, Socket, Type}; use socket2::{Domain, SockAddr, Socket, Type};
use tracing::{self, instrument}; use tracing::{self, instrument};
pub mod kern_control; mod kern_control;
pub mod sys; mod sys;
use kern_control::SysControlSocket; use kern_control::SysControlSocket;
pub use super::queue::TunQueue; pub use super::queue::TunQueue;
use super::{ifname_to_string, string_to_ifname}; use super::{ifname_to_string, string_to_ifname, TunOptions};
use crate::TunOptions;
#[derive(Debug)] #[derive(Debug)]
pub struct TunInterface { pub struct TunInterface {
@ -34,42 +33,8 @@ impl TunInterface {
#[throws] #[throws]
#[instrument] #[instrument]
pub fn new_with_options(options: TunOptions) -> TunInterface { pub fn new_with_options(_: TunOptions) -> TunInterface {
let ti = TunInterface::connect(0)?; TunInterface::connect(0)?
ti.configure(options)?;
ti
}
pub fn retrieve() -> Option<TunInterface> {
(3..100)
.filter_map(|fd| unsafe {
let peer_addr = socket2::SockAddr::init(|storage, len| {
*len = mem::size_of::<sys::sockaddr_ctl>() as u32;
libc::getpeername(fd, storage as *mut _, len);
Ok(())
})
.map(|(_, addr)| (fd, addr));
peer_addr.ok()
})
.filter(|(_fd, addr)| {
let ctl_addr = unsafe { &*(addr.as_ptr() as *const libc::sockaddr_ctl) };
addr.family() == libc::AF_SYSTEM as u8
&& ctl_addr.ss_sysaddr == libc::AF_SYS_CONTROL as u16
})
.map(|(fd, _)| {
let socket = unsafe { socket2::Socket::from_raw_fd(fd) };
TunInterface { socket }
})
.next()
}
#[throws]
fn configure(&self, options: TunOptions) {
if let Some(addr) = options.address {
if let Ok(addr) = addr.parse() {
self.set_ipv4_addr(addr)?;
}
}
} }
#[throws] #[throws]

View file

@ -2,11 +2,20 @@ use std::mem;
use libc::{c_char, c_int, c_short, c_uint, c_ulong, sockaddr}; use libc::{c_char, c_int, c_short, c_uint, c_ulong, sockaddr};
pub use libc::{ pub use libc::{
c_void, sockaddr_ctl, sockaddr_in, socklen_t, AF_SYSTEM, AF_SYS_CONTROL, IFNAMSIZ, c_void,
sockaddr_ctl,
sockaddr_in,
socklen_t,
AF_SYSTEM,
AF_SYS_CONTROL,
IFNAMSIZ,
SYSPROTO_CONTROL, SYSPROTO_CONTROL,
}; };
use nix::{ use nix::{
ioctl_read_bad, ioctl_readwrite, ioctl_write_ptr_bad, request_code_readwrite, ioctl_read_bad,
ioctl_readwrite,
ioctl_write_ptr_bad,
request_code_readwrite,
request_code_write, request_code_write,
}; };

View file

@ -26,9 +26,7 @@ pub struct TunInterface {
impl TunInterface { impl TunInterface {
#[throws] #[throws]
#[instrument] #[instrument]
pub fn new() -> TunInterface { pub fn new() -> TunInterface { Self::new_with_options(TunOptions::new())? }
Self::new_with_options(TunOptions::new())?
}
#[throws] #[throws]
#[instrument] #[instrument]
@ -214,7 +212,5 @@ impl TunInterface {
#[throws] #[throws]
#[instrument] #[instrument]
pub fn send(&self, buf: &[u8]) -> usize { pub fn send(&self, buf: &[u8]) -> usize { self.socket.send(buf)? }
self.socket.send(buf)?
}
} }

View file

@ -1,11 +1,12 @@
use std::{ use std::{
io::{Error, Read}, io::{Error, Read},
mem::MaybeUninit,
os::fd::{AsRawFd, FromRawFd, IntoRawFd, RawFd}, os::fd::{AsRawFd, FromRawFd, IntoRawFd, RawFd},
}; };
use tracing::instrument; use tracing::instrument;
use super::TunOptions;
mod queue; mod queue;
#[cfg(target_vendor = "apple")] #[cfg(target_vendor = "apple")]
@ -39,26 +40,21 @@ impl IntoRawFd for TunInterface {
} }
} }
unsafe fn assume_init(buf: &[MaybeUninit<u8>]) -> &[u8] {
&*(buf as *const [MaybeUninit<u8>] as *const [u8])
}
impl TunInterface { impl TunInterface {
#[throws] #[throws]
#[instrument] #[instrument]
pub fn recv(&self, buf: &mut [u8]) -> usize { pub fn recv(&mut self, buf: &mut [u8]) -> usize {
// Use IoVec to read directly into target buffer // there might be a more efficient way to implement this
let mut tmp_buf = [MaybeUninit::uninit(); 1500]; let tmp_buf = &mut [0u8; 1500];
let len = self.socket.recv(&mut tmp_buf)?; let len = self.socket.read(tmp_buf)?;
let result_buf = unsafe { assume_init(&tmp_buf[4..len]) }; buf[..len-4].copy_from_slice(&tmp_buf[4..len]);
buf[..len - 4].copy_from_slice(&result_buf); len-4
len - 4
} }
#[throws] #[throws]
#[instrument] #[instrument]
pub fn set_nonblocking(&mut self, nb: bool) { pub fn set_timeout(&self, timeout: Option<std::time::Duration>) {
self.socket.set_nonblocking(nb)?; self.socket.set_read_timeout(timeout)?;
} }
} }

View file

@ -25,9 +25,7 @@ impl Debug for TunInterface {
impl TunInterface { impl TunInterface {
#[throws] #[throws]
pub fn new() -> TunInterface { pub fn new() -> TunInterface { Self::new_with_options(TunOptions::new())? }
Self::new_with_options(TunOptions::new())?
}
#[throws] #[throws]
pub(crate) fn new_with_options(options: TunOptions) -> TunInterface { pub(crate) fn new_with_options(options: TunOptions) -> TunInterface {
@ -39,18 +37,17 @@ impl TunInterface {
if handle.is_null() { if handle.is_null() {
unsafe { GetLastError() }.ok()? unsafe { GetLastError() }.ok()?
} }
TunInterface { handle, name: name_owned } TunInterface {
handle,
name: name_owned,
}
} }
pub fn name(&self) -> String { pub fn name(&self) -> String { self.name.clone() }
self.name.clone()
}
} }
impl Drop for TunInterface { impl Drop for TunInterface {
fn drop(&mut self) { fn drop(&mut self) { unsafe { sys::WINTUN.WintunCloseAdapter(self.handle) } }
unsafe { sys::WINTUN.WintunCloseAdapter(self.handle) }
}
} }
pub(crate) mod sys { pub(crate) mod sys {

View file

@ -5,9 +5,7 @@ use tun::TunInterface;
#[test] #[test]
#[throws] #[throws]
fn test_create() { fn test_create() { TunInterface::new()?; }
TunInterface::new()?;
}
#[test] #[test]
#[throws] #[throws]