checkpoint

This commit is contained in:
Jett Chen 2023-12-10 03:44:31 +08:00
parent 17610ff90d
commit 94233874e6
39 changed files with 490 additions and 336 deletions

View file

@ -113,7 +113,7 @@ final class BurrowIpc {
return data return data
} }
func request<U: Decodable>(_ request: Request, type: U.Type) async throws -> U { func request<U: Decodable>(_ request: any Request, type: U.Type) async throws -> U {
do { do {
var data: Data = try JSONEncoder().encode(request) var data: Data = try JSONEncoder().encode(request)
data.append(contentsOf: [10]) data.append(contentsOf: [10])

View file

@ -7,16 +7,40 @@ enum BurrowError: Error {
case resultIsNone case resultIsNone
} }
protocol Request: Codable { protocol Request: Codable where T: Codable{
associatedtype T
var id: UInt { get set } var id: UInt { get set }
var command: String { get set } var command: T { get set }
} }
struct BurrowRequest: Request { struct BurrowSingleCommand: Request {
var id: UInt var id: UInt
var command: String var command: String
} }
struct BurrowRequest<T>: Request where T: Codable{
var id: UInt
var command: T
}
struct BurrowStartRequest: Codable {
struct TunOptions: Codable{
let name: String?
let no_pi: Bool
let tun_excl: Bool
let seek_utun: Int?
let address: String?
}
struct StartOptions: Codable{
let tun: TunOptions
}
let Start: StartOptions
}
func start_req_fd(id: UInt, fd: Int) -> BurrowRequest<BurrowStartRequest> {
return BurrowRequest(id: id, command: BurrowStartRequest(Start: BurrowStartRequest.StartOptions(tun: BurrowStartRequest.TunOptions(name: nil, no_pi: false, tun_excl: false, seek_utun: fd, address: nil))))
}
struct Response<T>: Decodable where T: Decodable { struct Response<T>: Decodable where T: Decodable {
var id: UInt var id: UInt
var result: T var result: T

View file

@ -17,7 +17,7 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
logger.info("Started server") logger.info("Started server")
Task { Task {
do { do {
let command = BurrowRequest(id: 0, command: "ServerConfig") let command = BurrowSingleCommand(id: 0, command: "ServerConfig")
guard let data = try await client?.request(command, type: Response<BurrowResult<ServerConfigData>>.self) guard let data = try await client?.request(command, type: Response<BurrowResult<ServerConfigData>>.self)
else { else {
throw BurrowError.cantParseResult throw BurrowError.cantParseResult
@ -32,6 +32,16 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
} }
try await self.setTunnelNetworkSettings(tunNs) try await self.setTunnelNetworkSettings(tunNs)
self.logger.info("Set remote tunnel address to \(tunNs.tunnelRemoteAddress)") self.logger.info("Set remote tunnel address to \(tunNs.tunnelRemoteAddress)")
// let tunFd = self.packetFlow.value(forKeyPath: "socket.fileDescriptor") as! Int;
// self.logger.info("Found File Descriptor: \(tunFd)")
let start_command = start_req_fd(id: 1, fd: 0)
guard let data = try await client?.request(start_command, type: Response<BurrowResult<String>>.self)
else {
throw BurrowError.cantParseResult
}
let encoded_startres = try JSONEncoder().encode(data.result)
self.logger.log("Received start server response: \(String(decoding: encoded_startres, as: UTF8.self))")
completionHandler(nil) completionHandler(nil)
} catch { } catch {
self.logger.error("An error occurred: \(error)") self.logger.error("An error occurred: \(error)")

View file

@ -1,15 +1,15 @@
use tracing::{debug, Subscriber};
use tracing::instrument::WithSubscriber; use tracing::instrument::WithSubscriber;
use tracing::{debug, Subscriber};
use tracing_oslog::OsLogger; use tracing_oslog::OsLogger;
use tracing_subscriber::FmtSubscriber;
use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::FmtSubscriber;
pub use crate::daemon::start_srv; pub use crate::daemon::start_srv;
#[no_mangle] #[no_mangle]
pub extern "C" fn initialize_oslog() { pub extern "C" fn initialize_oslog() {
let collector = tracing_subscriber::registry() let collector =
.with(OsLogger::new("com.hackclub.burrow", "backend")); tracing_subscriber::registry().with(OsLogger::new("com.hackclub.burrow", "backend"));
tracing::subscriber::set_global_default(collector).unwrap(); tracing::subscriber::set_global_default(collector).unwrap();
debug!("Initialized oslog tracing in libburrow rust FFI"); debug!("Initialized oslog tracing in libburrow rust FFI");
} }

View file

@ -17,16 +17,20 @@ pub struct DaemonStartOptions {
#[test] #[test]
fn test_daemoncommand_serialization() { fn test_daemoncommand_serialization() {
insta::assert_snapshot!(serde_json::to_string(&DaemonCommand::Start(
DaemonStartOptions::default()
))
.unwrap());
insta::assert_snapshot!( insta::assert_snapshot!(
serde_json::to_string(&DaemonCommand::Start(DaemonStartOptions::default())).unwrap() serde_json::to_string(&DaemonCommand::Start(DaemonStartOptions {
tun: TunOptions {
seek_utun: true,
..TunOptions::default()
}
}))
.unwrap()
); );
insta::assert_snapshot!( insta::assert_snapshot!(serde_json::to_string(&DaemonCommand::ServerInfo).unwrap());
serde_json::to_string(&DaemonCommand::ServerInfo).unwrap() insta::assert_snapshot!(serde_json::to_string(&DaemonCommand::Stop).unwrap());
); insta::assert_snapshot!(serde_json::to_string(&DaemonCommand::ServerConfig).unwrap())
insta::assert_snapshot!(
serde_json::to_string(&DaemonCommand::Stop).unwrap()
);
insta::assert_snapshot!(
serde_json::to_string(&DaemonCommand::ServerConfig).unwrap()
)
} }

View file

@ -1,11 +1,12 @@
use tokio::task::JoinHandle;
use tracing::{debug, info, warn};
use DaemonResponse;
use tun::tokio::TunInterface;
use crate::daemon::response::{DaemonResponseData, ServerConfig, ServerInfo};
use super::*; use super::*;
use crate::daemon::response::{DaemonResponseData, ServerConfig, ServerInfo};
use tokio::task::JoinHandle;
use tracing::field::debug;
use tracing::{debug, info, warn};
use tun::tokio::TunInterface;
use DaemonResponse;
enum RunState{ enum RunState {
Running(JoinHandle<Result<()>>), Running(JoinHandle<Result<()>>),
Idle, Idle,
} }
@ -42,34 +43,53 @@ impl DaemonInstance {
match command { match command {
DaemonCommand::Start(st) => { DaemonCommand::Start(st) => {
match self.wg_state { match self.wg_state {
RunState::Running(_) => {warn!("Got start, but tun interface already up.");} RunState::Running(_) => {
warn!("Got start, but tun interface already up.");
}
RunState::Idle => { RunState::Idle => {
debug!("Creating new TunInterface");
let tun_if = Arc::new(RwLock::new(TunInterface::new(st.tun.open()?)?)); let tun_if = Arc::new(RwLock::new(TunInterface::new(st.tun.open()?)?));
debug!("TunInterface created: {:?}", tun_if);
debug!("Setting tun_interface");
self.tun_interface = Some(tun_if.clone()); self.tun_interface = Some(tun_if.clone());
debug!("tun_interface set: {:?}", self.tun_interface);
debug!("Setting tun on wg_interface");
self.wg_interface.write().await.set_tun(tun_if); self.wg_interface.write().await.set_tun(tun_if);
debug!("tun set on wg_interface");
debug!("Cloning wg_interface");
let tmp_wg = self.wg_interface.clone(); let tmp_wg = self.wg_interface.clone();
debug!("wg_interface cloned");
debug!("Spawning run task");
let run_task = tokio::spawn(async move { let run_task = tokio::spawn(async move {
tmp_wg.read().await.run().await debug!("Running wg_interface");
let twlock = tmp_wg.read().await;
debug!("wg_interface read lock acquired");
twlock.run().await
}); });
debug!("Run task spawned: {:?}", run_task);
debug!("Setting wg_state to Running");
self.wg_state = RunState::Running(run_task); self.wg_state = RunState::Running(run_task);
debug!("wg_state set to Running");
info!("Daemon started tun interface"); info!("Daemon started tun interface");
} }
} }
Ok(DaemonResponseData::None) Ok(DaemonResponseData::None)
} }
DaemonCommand::ServerInfo => { DaemonCommand::ServerInfo => match &self.tun_interface {
match &self.tun_interface { None => Ok(DaemonResponseData::None),
None => {Ok(DaemonResponseData::None)}
Some(ti) => { Some(ti) => {
info!("{:?}", ti); info!("{:?}", ti);
Ok( Ok(DaemonResponseData::ServerInfo(ServerInfo::try_from(
DaemonResponseData::ServerInfo( ti.read().await.inner.get_ref(),
ServerInfo::try_from(ti.read().await.inner.get_ref())? )?))
)
)
}
}
} }
},
DaemonCommand::Stop => { DaemonCommand::Stop => {
if self.tun_interface.is_some() { if self.tun_interface.is_some() {
self.tun_interface = None; self.tun_interface = None;
@ -86,6 +106,7 @@ impl DaemonInstance {
} }
pub async fn run(&mut self) -> Result<()> { pub async fn run(&mut self) -> Result<()> {
tracing::info!("BEGIN");
while let Ok(command) = self.rx.recv().await { while let Ok(command) = self.rx.recv().await {
let response = self.proc_command(command).await; let response = self.proc_command(command).await;
info!("Daemon response: {:?}", response); info!("Daemon response: {:?}", response);

View file

@ -1,27 +1,26 @@
use std::net::{Ipv4Addr, SocketAddr, ToSocketAddrs}; use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4, ToSocketAddrs};
use std::sync::Arc; use std::sync::Arc;
mod command; mod command;
mod instance; mod instance;
mod net; mod net;
mod response; mod response;
use crate::wireguard::{Interface, Peer, PublicKey, StaticSecret};
use anyhow::{Error, Result}; use anyhow::{Error, Result};
use base64::{engine::general_purpose, Engine as _}; use base64::{engine::general_purpose, Engine as _};
pub use command::{DaemonCommand, DaemonStartOptions}; pub use command::{DaemonCommand, DaemonStartOptions};
use fehler::throws; use fehler::throws;
use ip_network::{IpNetwork, Ipv4Network};
use tokio::sync::RwLock;
use instance::DaemonInstance; use instance::DaemonInstance;
use crate::wireguard::{StaticSecret, Peer, Interface, PublicKey}; use ip_network::{IpNetwork, Ipv4Network};
pub use net::DaemonClient; pub use net::DaemonClient;
use tokio::sync::RwLock;
#[cfg(target_vendor = "apple")] #[cfg(target_vendor = "apple")]
pub use net::start_srv; pub use net::start_srv;
pub use response::{DaemonResponseData, DaemonResponse, ServerInfo};
use crate::daemon::net::listen; use crate::daemon::net::listen;
pub use response::{DaemonResponse, DaemonResponseData, ServerInfo};
#[throws] #[throws]
fn parse_key(string: &str) -> [u8; 32] { fn parse_key(string: &str) -> [u8; 32] {
@ -50,7 +49,7 @@ pub async fn daemon_main() -> Result<()> {
let private_key = parse_secret_key("GNqIAOCRxjl/cicZyvkvpTklgQuUmGUIEkH7IXF/sEE=")?; let private_key = parse_secret_key("GNqIAOCRxjl/cicZyvkvpTklgQuUmGUIEkH7IXF/sEE=")?;
let public_key = parse_public_key("uy75leriJay0+oHLhRMpV+A5xAQ0hCJ+q7Ww81AOvT4=")?; let public_key = parse_public_key("uy75leriJay0+oHLhRMpV+A5xAQ0hCJ+q7Ww81AOvT4=")?;
let preshared_key = Some(parse_key("s7lx/mg+reVEMnGnqeyYOQkzD86n2+gYnx1M9ygi08k=")?); let preshared_key = Some(parse_key("s7lx/mg+reVEMnGnqeyYOQkzD86n2+gYnx1M9ygi08k=")?);
let endpoint = "wg.burrow.rs:51820".to_socket_addrs()?.next().unwrap(); let endpoint = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(198, 18, 6, 180)), 51820); // DNS lookup under macos fails, somehow
let iface = Interface::new(vec![Peer { let iface = Interface::new(vec![Peer {
endpoint, endpoint,
@ -62,7 +61,22 @@ pub async fn daemon_main() -> Result<()> {
let mut inst = DaemonInstance::new(commands_rx, response_tx, Arc::new(RwLock::new(iface))); let mut inst = DaemonInstance::new(commands_rx, response_tx, Arc::new(RwLock::new(iface)));
tokio::try_join!(inst.run(), listen(commands_tx, response_rx)) tracing::info!("Starting daemon jobs...");
.map(|_| {()});
let inst_job = tokio::spawn(async move {
let res = inst.run().await;
if let Err(e) = res {
tracing::error!("Error when running instance: {}", e);
}
});
let listen_job = tokio::spawn(async move {
let res = listen(commands_tx, response_rx).await;
if let Err(e) = res {
tracing::error!("Error when listening: {}", e);
}
});
tokio::try_join!(inst_job, listen_job).map(|_| ());
Ok(()) Ok(())
} }

View file

@ -1,10 +1,12 @@
use crate::daemon::{daemon_main, DaemonClient};
use std::future::Future;
use std::thread; use std::thread;
use tokio::runtime::Runtime; use tokio::runtime::Runtime;
use tracing::error; use tracing::{error, info};
use crate::daemon::{daemon_main, DaemonClient};
#[no_mangle] #[no_mangle]
pub extern "C" fn start_srv(){ pub extern "C" fn start_srv() {
info!("Rust: Starting server");
let _handle = thread::spawn(move || { let _handle = thread::spawn(move || {
let rt = Runtime::new().unwrap(); let rt = Runtime::new().unwrap();
rt.block_on(async { rt.block_on(async {
@ -16,8 +18,11 @@ pub extern "C" fn start_srv(){
let rt = Runtime::new().unwrap(); let rt = Runtime::new().unwrap();
rt.block_on(async { rt.block_on(async {
loop { loop {
if let Ok(_) = DaemonClient::new().await{ match DaemonClient::new().await {
break Ok(_) => break,
Err(e) => {
// error!("Error when connecting to daemon: {}", e)
}
} }
} }
}); });

View file

@ -29,4 +29,3 @@ pub struct DaemonRequest {
pub id: u32, pub id: u32,
pub command: DaemonCommand, pub command: DaemonCommand,
} }

View file

@ -1,13 +1,23 @@
pub async fn listen(cmd_tx: async_channel::Sender<DaemonCommand>, rsp_rx: async_channel::Receiver<DaemonResponse>) -> Result<()> { pub async fn listen(
if !libsystemd::daemon::booted() || listen_with_systemd(cmd_tx.clone(), rsp_rx.clone()).await.is_err() { cmd_tx: async_channel::Sender<DaemonCommand>,
rsp_rx: async_channel::Receiver<DaemonResponse>,
) -> Result<()> {
if !libsystemd::daemon::booted()
|| listen_with_systemd(cmd_tx.clone(), rsp_rx.clone())
.await
.is_err()
{
unix::listen(cmd_tx, rsp_rx).await?; unix::listen(cmd_tx, rsp_rx).await?;
} }
Ok(()) Ok(())
} }
async fn listen_with_systemd(cmd_tx: async_channel::Sender<DaemonCommand>, rsp_rx: async_channel::Receiver<DaemonResponse>) -> Result<()> { async fn listen_with_systemd(
cmd_tx: async_channel::Sender<DaemonCommand>,
rsp_rx: async_channel::Receiver<DaemonResponse>,
) -> Result<()> {
let fds = libsystemd::activation::receive_descriptors(false)?; let fds = libsystemd::activation::receive_descriptors(false)?;
super::unix::listen_with_optional_fd(cmd_tx, rsp_rx,Some(fds[0].clone().into_raw_fd())).await super::unix::listen_with_optional_fd(cmd_tx, rsp_rx, Some(fds[0].clone().into_raw_fd())).await
} }
pub type DaemonClient = unix::DaemonClient; pub type DaemonClient = unix::DaemonClient;

View file

@ -1,13 +1,17 @@
use super::*; use super::*;
use anyhow::anyhow;
use log::log;
use std::hash::Hash;
use std::path::PathBuf;
use std::{ use std::{
ascii, io, os::{ ascii, io,
os::{
fd::{FromRawFd, RawFd}, fd::{FromRawFd, RawFd},
unix::net::UnixListener as StdUnixListener, unix::net::UnixListener as StdUnixListener,
}, },
path::Path}; path::Path};
use std::hash::Hash;
use std::path::PathBuf; use anyhow::Result;
use anyhow::{anyhow, Result};
use tokio::{ use tokio::{
io::{AsyncBufReadExt, AsyncWriteExt, BufReader}, io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
net::{UnixListener, UnixStream}, net::{UnixListener, UnixStream},
@ -102,6 +106,7 @@ pub(crate) async fn listen_with_optional_fd(
Ok(req) => Some(req), Ok(req) => Some(req),
Err(e) => { Err(e) => {
res.result = Err(e.to_string()); res.result = Err(e.to_string());
tracing::error!("Failed to parse request: {}", e);
None None
} }
}; };
@ -115,6 +120,8 @@ pub(crate) async fn listen_with_optional_fd(
retres.push('\n'); retres.push('\n');
info!("Sending response: {}", retres); info!("Sending response: {}", retres);
write_stream.write_all(retres.as_bytes()).await.unwrap(); write_stream.write_all(retres.as_bytes()).await.unwrap();
} else {
write_stream.write_all(res.as_bytes()).await.unwrap();
} }
} }
}); });

View file

@ -1,6 +1,9 @@
use super::*; use super::*;
pub async fn listen(_cmd_tx: async_channel::Sender<DaemonCommand>, _rsp_rx: async_channel::Receiver<DaemonResponse>) -> Result<()> { pub async fn listen(
_cmd_tx: async_channel::Sender<DaemonCommand>,
_rsp_rx: async_channel::Receiver<DaemonResponse>,
) -> Result<()> {
unimplemented!("This platform does not currently support daemon mode.") unimplemented!("This platform does not currently support daemon mode.")
} }

View file

@ -7,30 +7,27 @@ use tun::TunInterface;
pub struct DaemonResponse { pub struct DaemonResponse {
// Error types can't be serialized, so this is the second best option. // Error types can't be serialized, so this is the second best option.
pub result: Result<DaemonResponseData, String>, pub result: Result<DaemonResponseData, String>,
pub id: u32 pub id: u32,
} }
impl DaemonResponse{ impl DaemonResponse {
pub fn new(result: Result<DaemonResponseData, impl ToString>) -> Self{ pub fn new(result: Result<DaemonResponseData, impl ToString>) -> Self {
Self{ Self {
result: result.map_err(|e| e.to_string()), result: result.map_err(|e| e.to_string()),
id: 0 id: 0,
} }
} }
} }
impl Into<DaemonResponse> for DaemonResponseData{ impl Into<DaemonResponse> for DaemonResponseData {
fn into(self) -> DaemonResponse{ fn into(self) -> DaemonResponse {
DaemonResponse::new(Ok::<DaemonResponseData, String>(self)) DaemonResponse::new(Ok::<DaemonResponseData, String>(self))
} }
} }
impl DaemonResponse{ impl DaemonResponse {
pub fn with_id(self, id: u32) -> Self{ pub fn with_id(self, id: u32) -> Self {
Self { Self { id, ..self }
id,
..self
}
} }
} }
@ -38,24 +35,22 @@ impl DaemonResponse{
pub struct ServerInfo { pub struct ServerInfo {
pub name: Option<String>, pub name: Option<String>,
pub ip: Option<String>, pub ip: Option<String>,
pub mtu: Option<i32> pub mtu: Option<i32>,
} }
impl TryFrom<&TunInterface> for ServerInfo{ impl TryFrom<&TunInterface> for ServerInfo {
type Error = anyhow::Error; type Error = anyhow::Error;
#[cfg(any(target_os="linux",target_vendor="apple"))] #[cfg(any(target_os = "linux", target_vendor = "apple"))]
fn try_from(server: &TunInterface) -> anyhow::Result<Self> { fn try_from(server: &TunInterface) -> anyhow::Result<Self> {
Ok( Ok(ServerInfo {
ServerInfo{
name: server.name().ok(), name: server.name().ok(),
ip: server.ipv4_addr().ok().map(|ip| ip.to_string()), ip: server.ipv4_addr().ok().map(|ip| ip.to_string()),
mtu: server.mtu().ok() mtu: server.mtu().ok(),
} })
)
} }
#[cfg(not(any(target_os="linux",target_vendor="apple")))] #[cfg(not(any(target_os = "linux", target_vendor = "apple")))]
fn try_from(server: &TunInterface) -> anyhow::Result<Self> { fn try_from(server: &TunInterface) -> anyhow::Result<Self> {
Err(anyhow!("Not implemented in this platform")) Err(anyhow!("Not implemented in this platform"))
} }
@ -65,45 +60,55 @@ impl TryFrom<&TunInterface> for ServerInfo{
pub struct ServerConfig { pub struct ServerConfig {
pub address: Option<String>, pub address: Option<String>,
pub name: Option<String>, pub name: Option<String>,
pub mtu: Option<i32> pub mtu: Option<i32>,
} }
impl Default for ServerConfig { impl Default for ServerConfig {
fn default() -> Self { fn default() -> Self {
Self{ Self {
address: Some("10.0.0.1".to_string()), // Dummy remote address address: Some("10.13.13.2".to_string()), // Dummy remote address
name: None, name: None,
mtu: None mtu: None,
} }
} }
} }
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
pub enum DaemonResponseData{ pub enum DaemonResponseData {
ServerInfo(ServerInfo), ServerInfo(ServerInfo),
ServerConfig(ServerConfig), ServerConfig(ServerConfig),
None None,
} }
#[test] #[test]
fn test_response_serialization() -> anyhow::Result<()>{ fn test_response_serialization() -> anyhow::Result<()> {
insta::assert_snapshot!( insta::assert_snapshot!(serde_json::to_string(&DaemonResponse::new(Ok::<
serde_json::to_string(&DaemonResponse::new(Ok::<DaemonResponseData, String>(DaemonResponseData::None)))? DaemonResponseData,
); String,
insta::assert_snapshot!( >(
serde_json::to_string(&DaemonResponse::new(Ok::<DaemonResponseData, String>(DaemonResponseData::ServerInfo(ServerInfo{ DaemonResponseData::None
)))?);
insta::assert_snapshot!(serde_json::to_string(&DaemonResponse::new(Ok::<
DaemonResponseData,
String,
>(
DaemonResponseData::ServerInfo(ServerInfo {
name: Some("burrow".to_string()), name: Some("burrow".to_string()),
ip: None, ip: None,
mtu: Some(1500) mtu: Some(1500)
}))))? })
); )))?);
insta::assert_snapshot!( insta::assert_snapshot!(serde_json::to_string(&DaemonResponse::new(Err::<
serde_json::to_string(&DaemonResponse::new(Err::<DaemonResponseData, String>("error".to_string())))? DaemonResponseData,
); String,
insta::assert_snapshot!( >(
serde_json::to_string(&DaemonResponse::new(Ok::<DaemonResponseData, String>(DaemonResponseData::ServerConfig( "error".to_string()
ServerConfig::default() )))?);
))))? insta::assert_snapshot!(serde_json::to_string(&DaemonResponse::new(Ok::<
); DaemonResponseData,
String,
>(
DaemonResponseData::ServerConfig(ServerConfig::default())
)))?);
Ok(()) Ok(())
} }

View file

@ -1,5 +1,5 @@
--- ---
source: burrow/src/daemon/command.rs source: burrow/src/daemon/command.rs
expression: "serde_json::to_string(&DaemonCommand::ServerInfo).unwrap()" expression: "serde_json::to_string(&DaemonCommand::Start(DaemonStartOptions {\n tun: TunOptions { seek_utun: true, ..TunOptions::default() },\n })).unwrap()"
--- ---
"ServerInfo" {"Start":{"tun":{"name":null,"no_pi":false,"tun_excl":false,"seek_utun":true,"address":null}}}

View file

@ -1,5 +1,5 @@
--- ---
source: burrow/src/daemon/command.rs source: burrow/src/daemon/command.rs
expression: "serde_json::to_string(&DaemonCommand::Stop).unwrap()" expression: "serde_json::to_string(&DaemonCommand::ServerInfo).unwrap()"
--- ---
"Stop" "ServerInfo"

View file

@ -1,5 +1,5 @@
--- ---
source: burrow/src/daemon/command.rs source: burrow/src/daemon/command.rs
expression: "serde_json::to_string(&DaemonCommand::ServerConfig).unwrap()" expression: "serde_json::to_string(&DaemonCommand::Stop).unwrap()"
--- ---
"ServerConfig" "Stop"

View file

@ -0,0 +1,5 @@
---
source: burrow/src/daemon/command.rs
expression: "serde_json::to_string(&DaemonCommand::ServerConfig).unwrap()"
---
"ServerConfig"

View file

@ -2,4 +2,4 @@
source: burrow/src/daemon/command.rs source: burrow/src/daemon/command.rs
expression: "serde_json::to_string(&DaemonCommand::Start(DaemonStartOptions::default())).unwrap()" expression: "serde_json::to_string(&DaemonCommand::Start(DaemonStartOptions::default())).unwrap()"
--- ---
{"Start":{"tun":{"name":null,"no_pi":null,"tun_excl":null}}} {"Start":{"tun":{"name":null,"no_pi":false,"tun_excl":false,"seek_utun":false,"address":null}}}

View file

@ -2,4 +2,4 @@
source: burrow/src/daemon/response.rs source: burrow/src/daemon/response.rs
expression: "serde_json::to_string(&DaemonResponse::new(Ok::<DaemonResponseData,\n String>(DaemonResponseData::ServerConfig(ServerConfig::default()))))?" expression: "serde_json::to_string(&DaemonResponse::new(Ok::<DaemonResponseData,\n String>(DaemonResponseData::ServerConfig(ServerConfig::default()))))?"
--- ---
{"result":{"Ok":{"ServerConfig":{"address":"10.0.0.1","name":null,"mtu":null}}},"id":0} {"result":{"Ok":{"ServerConfig":{"address":"10.13.13.2","name":null,"mtu":null}}},"id":0}

View file

@ -23,24 +23,3 @@ mod apple;
#[cfg(target_vendor = "apple")] #[cfg(target_vendor = "apple")]
pub use apple::*; pub use apple::*;
#[cfg(any(target_os = "linux", target_vendor = "apple"))]
#[no_mangle]
pub extern "C" fn retrieve() -> i32 {
let iface2 = (1..100)
.filter_map(|i| {
let iface = unsafe { TunInterface::from_raw_fd(i) };
match iface.name() {
Ok(_name) => Some(iface),
Err(_) => {
mem::forget(iface);
None
}
}
})
.next();
match iface2 {
Some(iface) => iface.as_raw_fd(),
None => -1,
}
}

View file

@ -4,22 +4,20 @@ use std::os::fd::FromRawFd;
use anyhow::{Context, Result}; use anyhow::{Context, Result};
#[cfg(any(target_os = "linux", target_vendor = "apple"))] #[cfg(any(target_os = "linux", target_vendor = "apple"))]
use burrow::retrieve;
use clap::{Args, Parser, Subcommand}; use clap::{Args, Parser, Subcommand};
use tracing::{instrument, Level}; use tracing::{instrument, Level};
use tracing_log::LogTracer; use tracing_log::LogTracer;
use tracing_oslog::OsLogger; use tracing_oslog::OsLogger;
use tracing_subscriber::{prelude::*, FmtSubscriber, EnvFilter}; use tracing_subscriber::{prelude::*, EnvFilter, FmtSubscriber};
#[cfg(any(target_os = "linux", target_vendor = "apple"))] #[cfg(any(target_os = "linux", target_vendor = "apple"))]
use tun::TunInterface; use tun::{retrieve, TunInterface};
mod daemon; mod daemon;
mod wireguard; mod wireguard;
use crate::daemon::DaemonResponseData;
use daemon::{DaemonClient, DaemonCommand, DaemonStartOptions}; use daemon::{DaemonClient, DaemonCommand, DaemonStartOptions};
use tun::TunOptions; use tun::TunOptions;
use crate::daemon::DaemonResponseData;
#[derive(Parser)] #[derive(Parser)]
#[command(name = "Burrow")] #[command(name = "Burrow")]
@ -66,11 +64,9 @@ struct DaemonArgs {}
async fn try_start() -> Result<()> { async fn try_start() -> Result<()> {
let mut client = DaemonClient::new().await?; let mut client = DaemonClient::new().await?;
client client
.send_command(DaemonCommand::Start( .send_command(DaemonCommand::Start(DaemonStartOptions {
DaemonStartOptions{ tun: TunOptions::new().address("10.13.13.2"),
tun: TunOptions::new().address("10.13.13.2") }))
}
))
.await .await
.map(|_| ()) .map(|_| ())
} }
@ -93,8 +89,8 @@ async fn try_retrieve() -> Result<()> {
} }
burrow::ensureroot::ensure_root(); burrow::ensureroot::ensure_root();
let iface2 = retrieve(); let iface2 = retrieve().ok_or(anyhow::anyhow!("No interface found"))?;
tracing::info!("{}", iface2); tracing::info!("{:?}", iface2);
Ok(()) Ok(())
} }
@ -109,9 +105,10 @@ async fn initialize_tracing() -> Result<()> {
FmtSubscriber::builder() FmtSubscriber::builder()
.with_line_number(true) .with_line_number(true)
.with_env_filter(EnvFilter::from_default_env()) .with_env_filter(EnvFilter::from_default_env())
.finish() .finish(),
); );
tracing::subscriber::set_global_default(logger).context("Failed to set the global tracing subscriber")?; tracing::subscriber::set_global_default(logger)
.context("Failed to set the global tracing subscriber")?;
} }
} }
@ -126,7 +123,7 @@ async fn try_stop() -> Result<()> {
} }
#[cfg(any(target_os = "linux", target_vendor = "apple"))] #[cfg(any(target_os = "linux", target_vendor = "apple"))]
async fn try_serverinfo() -> Result<()>{ async fn try_serverinfo() -> Result<()> {
let mut client = DaemonClient::new().await?; let mut client = DaemonClient::new().await?;
let res = client.send_command(DaemonCommand::ServerInfo).await?; let res = client.send_command(DaemonCommand::ServerInfo).await?;
match res.result { match res.result {
@ -136,7 +133,9 @@ async fn try_serverinfo() -> Result<()>{
Ok(DaemonResponseData::None) => { Ok(DaemonResponseData::None) => {
println!("Server not started.") println!("Server not started.")
} }
Ok(res) => {println!("Unexpected Response: {:?}", res)} Ok(res) => {
println!("Unexpected Response: {:?}", res)
}
Err(e) => { Err(e) => {
println!("Error when retrieving from server: {}", e) println!("Error when retrieving from server: {}", e)
} }
@ -145,7 +144,7 @@ async fn try_serverinfo() -> Result<()>{
} }
#[cfg(any(target_os = "linux", target_vendor = "apple"))] #[cfg(any(target_os = "linux", target_vendor = "apple"))]
async fn try_serverconfig() -> Result<()>{ async fn try_serverconfig() -> Result<()> {
let mut client = DaemonClient::new().await?; let mut client = DaemonClient::new().await?;
let res = client.send_command(DaemonCommand::ServerConfig).await?; let res = client.send_command(DaemonCommand::ServerConfig).await?;
match res.result { match res.result {
@ -155,7 +154,9 @@ async fn try_serverconfig() -> Result<()>{
Ok(DaemonResponseData::None) => { Ok(DaemonResponseData::None) => {
println!("Server not started.") println!("Server not started.")
} }
Ok(res) => {println!("Unexpected Response: {:?}", res)} Ok(res) => {
println!("Unexpected Response: {:?}", res)
}
Err(e) => { Err(e) => {
println!("Error when retrieving from server: {}", e) println!("Error when retrieving from server: {}", e)
} }
@ -206,12 +207,8 @@ async fn main() -> Result<()> {
try_stop().await?; try_stop().await?;
} }
Commands::Daemon(_) => daemon::daemon_main().await?, Commands::Daemon(_) => daemon::daemon_main().await?,
Commands::ServerInfo => { Commands::ServerInfo => try_serverinfo().await?,
try_serverinfo().await? Commands::ServerConfig => try_serverconfig().await?,
}
Commands::ServerConfig => {
try_serverconfig().await?
}
} }
Ok(()) Ok(())

View file

@ -1,21 +1,22 @@
use std::{net::IpAddr, rc::Rc};
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use std::{net::IpAddr, rc::Rc};
use anyhow::Error; use anyhow::Error;
use async_trait::async_trait; use async_trait::async_trait;
use fehler::throws; use fehler::throws;
use futures::future::join_all;
use futures::FutureExt;
use ip_network_table::IpNetworkTable; use ip_network_table::IpNetworkTable;
use log::log; use log::log;
use tokio::time::timeout;
use tokio::{ use tokio::{
join, join,
sync::{Mutex, RwLock}, sync::{Mutex, RwLock},
task::{self, JoinHandle}, task::{self, JoinHandle},
}; };
use tracing::{debug, error};
use tun::tokio::TunInterface; use tun::tokio::TunInterface;
use futures::future::join_all;
use futures::FutureExt;
use tokio::time::timeout;
use super::{noise::Tunnel, pcb, Peer, PeerPcb}; use super::{noise::Tunnel, pcb, Peer, PeerPcb};
@ -90,10 +91,7 @@ impl Interface {
.collect::<Result<_, _>>()?; .collect::<Result<_, _>>()?;
let pcbs = Arc::new(pcbs); let pcbs = Arc::new(pcbs);
Self { Self { pcbs, tun: None }
pcbs,
tun: None
}
} }
pub fn set_tun(&mut self, tun: Arc<RwLock<TunInterface>>) { pub fn set_tun(&mut self, tun: Arc<RwLock<TunInterface>>) {
@ -101,66 +99,82 @@ impl Interface {
} }
pub async fn run(&self) -> anyhow::Result<()> { pub async fn run(&self) -> anyhow::Result<()> {
debug!("RUN: starting interface");
let pcbs = self.pcbs.clone(); let pcbs = self.pcbs.clone();
let tun = self.tun.clone().ok_or(anyhow::anyhow!("tun interface does not exist"))?; let tun = self
.tun
.clone()
.ok_or(anyhow::anyhow!("tun interface does not exist"))?;
log::info!("starting interface"); log::info!("starting interface");
let outgoing = async move { let outgoing = async move {
loop { loop {
// log::debug!("starting loop..."); // tracing::debug!("starting loop...");
let mut buf = [0u8; 3000]; let mut buf = [0u8; 3000];
let src = { let src = {
// log::debug!("awaiting read..."); let src = match timeout(
let src = match timeout(Duration::from_millis(10), tun.write().await.recv(&mut buf[..])).await { Duration::from_millis(10),
tun.write().await.recv(&mut buf[..]),
)
.await
{
Ok(Ok(len)) => &buf[..len], Ok(Ok(len)) => &buf[..len],
Ok(Err(e)) => {continue} Ok(Err(e)) => {
error!("failed to read from interface: {}", e);
continue;
}
Err(_would_block) => { Err(_would_block) => {
continue debug!("read timed out");
continue;
} }
}; };
log::debug!("read {} bytes from interface", src.len()); debug!("read {} bytes from interface", src.len());
log::debug!("bytes: {:?}", src); debug!("bytes: {:?}", src);
src src
}; };
let dst_addr = match Tunnel::dst_address(src) { let dst_addr = match Tunnel::dst_address(src) {
Some(addr) => addr, Some(addr) => addr,
None => { None => {
log::debug!("no destination found"); tracing::debug!("no destination found");
continue continue;
}, }
}; };
log::debug!("dst_addr: {}", dst_addr); tracing::debug!("dst_addr: {}", dst_addr);
let Some(idx) = pcbs.find(dst_addr) else { let Some(idx) = pcbs.find(dst_addr) else {
continue continue
}; };
log::debug!("found peer:{}", idx); tracing::debug!("found peer:{}", idx);
match pcbs.pcbs[idx].read().await.send(src).await { match pcbs.pcbs[idx].read().await.send(src).await {
Ok(..) => { Ok(..) => {
log::debug!("sent packet to peer {}", dst_addr); tracing::debug!("sent packet to peer {}", dst_addr);
} }
Err(e) => { Err(e) => {
log::error!("failed to send packet {}", e); log::error!("failed to send packet {}", e);
continue continue;
}, }
}; };
} }
}; };
let mut tsks = vec![]; let mut tsks = vec![];
let tun = self.tun.clone().ok_or(anyhow::anyhow!("tun interface does not exist"))?; let tun = self
.tun
.clone()
.ok_or(anyhow::anyhow!("tun interface does not exist"))?;
let outgoing = tokio::task::spawn(outgoing); let outgoing = tokio::task::spawn(outgoing);
tsks.push(outgoing); tsks.push(outgoing);
debug!("preparing to spawn read tasks");
{ {
let pcbs = &self.pcbs; let pcbs = &self.pcbs;
for i in 0..pcbs.pcbs.len(){ for i in 0..pcbs.pcbs.len() {
debug!("spawning read task for peer {}", i);
let mut pcb = pcbs.pcbs[i].clone(); let mut pcb = pcbs.pcbs[i].clone();
let tun = tun.clone(); let tun = tun.clone();
let tsk = async move { let tsk = async move {
@ -168,23 +182,25 @@ impl Interface {
let r1 = pcb.write().await.open_if_closed().await; let r1 = pcb.write().await.open_if_closed().await;
if let Err(e) = r1 { if let Err(e) = r1 {
log::error!("failed to open pcb: {}", e); log::error!("failed to open pcb: {}", e);
return return;
} }
} }
let r2 = pcb.read().await.run(tun).await; let r2 = pcb.read().await.run(tun).await;
if let Err(e) = r2 { if let Err(e) = r2 {
log::error!("failed to run pcb: {}", e); log::error!("failed to run pcb: {}", e);
return return;
} else { } else {
log::debug!("pcb ran successfully"); tracing::debug!("pcb ran successfully");
} }
}; };
debug!("task made..");
tsks.push(tokio::spawn(tsk)); tsks.push(tokio::spawn(tsk));
} }
log::debug!("spawned read tasks"); debug!("spawned read tasks");
} }
log::debug!("preparing to join.."); debug!("preparing to join..");
join_all(tsks).await; join_all(tsks).await;
debug!("joined!");
Ok(()) Ok(())
} }
} }

View file

@ -9,20 +9,14 @@ use std::{
use aead::{Aead, Payload}; use aead::{Aead, Payload};
use blake2::{ use blake2::{
digest::{FixedOutput, KeyInit}, digest::{FixedOutput, KeyInit},
Blake2s256, Blake2s256, Blake2sMac, Digest,
Blake2sMac,
Digest,
}; };
use chacha20poly1305::XChaCha20Poly1305; use chacha20poly1305::XChaCha20Poly1305;
use rand_core::OsRng; use rand_core::OsRng;
use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305}; use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305};
use super::{ use super::{
errors::WireGuardError, errors::WireGuardError, session::Session, x25519, HandshakeInit, HandshakeResponse,
session::Session,
x25519,
HandshakeInit,
HandshakeResponse,
PacketCookieReply, PacketCookieReply,
}; };
@ -137,8 +131,8 @@ fn aead_chacha20_open(
let mut nonce: [u8; 12] = [0; 12]; let mut nonce: [u8; 12] = [0; 12];
nonce[4..].copy_from_slice(&counter.to_le_bytes()); nonce[4..].copy_from_slice(&counter.to_le_bytes());
log::debug!("TAG A"); tracing::debug!("TAG A");
log::debug!("{:?};{:?};{:?};{:?};{}", key, data, aad, nonce, counter); tracing::debug!("{:?};{:?};{:?};{:?};{}", key, data, aad, nonce, counter);
aead_chacha20_open_inner(buffer, key, nonce, data, aad) aead_chacha20_open_inner(buffer, key, nonce, data, aad)
.map_err(|_| WireGuardError::InvalidAeadTag)?; .map_err(|_| WireGuardError::InvalidAeadTag)?;
@ -213,7 +207,7 @@ impl Tai64N {
/// Parse a timestamp from a 12 byte u8 slice /// Parse a timestamp from a 12 byte u8 slice
fn parse(buf: &[u8; 12]) -> Result<Tai64N, WireGuardError> { fn parse(buf: &[u8; 12]) -> Result<Tai64N, WireGuardError> {
if buf.len() < 12 { if buf.len() < 12 {
return Err(WireGuardError::InvalidTai64nTimestamp) return Err(WireGuardError::InvalidTai64nTimestamp);
} }
let (sec_bytes, nano_bytes) = buf.split_at(std::mem::size_of::<u64>()); let (sec_bytes, nano_bytes) = buf.split_at(std::mem::size_of::<u64>());
@ -560,19 +554,22 @@ impl Handshake {
let timestamp = Tai64N::parse(&timestamp)?; let timestamp = Tai64N::parse(&timestamp)?;
if !timestamp.after(&self.last_handshake_timestamp) { if !timestamp.after(&self.last_handshake_timestamp) {
// Possibly a replay // Possibly a replay
return Err(WireGuardError::WrongTai64nTimestamp) return Err(WireGuardError::WrongTai64nTimestamp);
} }
self.last_handshake_timestamp = timestamp; self.last_handshake_timestamp = timestamp;
// initiator.hash = HASH(initiator.hash || msg.encrypted_timestamp) // initiator.hash = HASH(initiator.hash || msg.encrypted_timestamp)
hash = b2s_hash(&hash, packet.encrypted_timestamp); hash = b2s_hash(&hash, packet.encrypted_timestamp);
self.previous = std::mem::replace(&mut self.state, HandshakeState::InitReceived { self.previous = std::mem::replace(
&mut self.state,
HandshakeState::InitReceived {
chaining_key, chaining_key,
hash, hash,
peer_ephemeral_public, peer_ephemeral_public,
peer_index, peer_index,
}); },
);
self.format_handshake_response(dst) self.format_handshake_response(dst)
} }
@ -673,7 +670,7 @@ impl Handshake {
let local_index = self.cookies.index; let local_index = self.cookies.index;
if packet.receiver_idx != local_index { if packet.receiver_idx != local_index {
return Err(WireGuardError::WrongIndex) return Err(WireGuardError::WrongIndex);
} }
// msg.encrypted_cookie = XAEAD(HASH(LABEL_COOKIE || responder.static_public), // msg.encrypted_cookie = XAEAD(HASH(LABEL_COOKIE || responder.static_public),
// msg.nonce, cookie, last_received_msg.mac1) // msg.nonce, cookie, last_received_msg.mac1)
@ -683,7 +680,7 @@ impl Handshake {
aad: &mac1[0..16], aad: &mac1[0..16],
msg: packet.encrypted_cookie, msg: packet.encrypted_cookie,
}; };
log::debug!("TAG B"); tracing::debug!("TAG B");
let plaintext = XChaCha20Poly1305::new_from_slice(&key) let plaintext = XChaCha20Poly1305::new_from_slice(&key)
.unwrap() .unwrap()
.decrypt(packet.nonce.into(), payload) .decrypt(packet.nonce.into(), payload)
@ -730,7 +727,7 @@ impl Handshake {
dst: &'a mut [u8], dst: &'a mut [u8],
) -> Result<&'a mut [u8], WireGuardError> { ) -> Result<&'a mut [u8], WireGuardError> {
if dst.len() < super::HANDSHAKE_INIT_SZ { if dst.len() < super::HANDSHAKE_INIT_SZ {
return Err(WireGuardError::DestinationBufferTooSmall) return Err(WireGuardError::DestinationBufferTooSmall);
} }
let (message_type, rest) = dst.split_at_mut(4); let (message_type, rest) = dst.split_at_mut(4);
@ -813,7 +810,7 @@ impl Handshake {
dst: &'a mut [u8], dst: &'a mut [u8],
) -> Result<(&'a mut [u8], Session), WireGuardError> { ) -> Result<(&'a mut [u8], Session), WireGuardError> {
if dst.len() < super::HANDSHAKE_RESP_SZ { if dst.len() < super::HANDSHAKE_RESP_SZ {
return Err(WireGuardError::DestinationBufferTooSmall) return Err(WireGuardError::DestinationBufferTooSmall);
} }
let state = std::mem::replace(&mut self.state, HandshakeState::None); let state = std::mem::replace(&mut self.state, HandshakeState::None);

View file

@ -45,11 +45,7 @@ const N_SESSIONS: usize = 8;
pub mod x25519 { pub mod x25519 {
pub use x25519_dalek::{ pub use x25519_dalek::{
EphemeralSecret, EphemeralSecret, PublicKey, ReusableSecret, SharedSecret, StaticSecret,
PublicKey,
ReusableSecret,
SharedSecret,
StaticSecret,
}; };
} }
@ -141,12 +137,12 @@ impl Tunnel {
#[inline(always)] #[inline(always)]
pub fn parse_incoming_packet(src: &[u8]) -> Result<Packet, WireGuardError> { pub fn parse_incoming_packet(src: &[u8]) -> Result<Packet, WireGuardError> {
if src.len() < 4 { if src.len() < 4 {
return Err(WireGuardError::InvalidPacket) return Err(WireGuardError::InvalidPacket);
} }
// Checks the type, as well as the reserved zero fields // Checks the type, as well as the reserved zero fields
let packet_type = u32::from_le_bytes(src[0..4].try_into().unwrap()); let packet_type = u32::from_le_bytes(src[0..4].try_into().unwrap());
log::debug!("packet_type: {}", packet_type); tracing::debug!("packet_type: {}", packet_type);
Ok(match (packet_type, src.len()) { Ok(match (packet_type, src.len()) {
(HANDSHAKE_INIT, HANDSHAKE_INIT_SZ) => Packet::HandshakeInit(HandshakeInit { (HANDSHAKE_INIT, HANDSHAKE_INIT_SZ) => Packet::HandshakeInit(HandshakeInit {
@ -183,7 +179,7 @@ impl Tunnel {
pub fn dst_address(packet: &[u8]) -> Option<IpAddr> { pub fn dst_address(packet: &[u8]) -> Option<IpAddr> {
if packet.is_empty() { if packet.is_empty() {
return None return None;
} }
match packet[0] >> 4 { match packet[0] >> 4 {
@ -278,7 +274,7 @@ impl Tunnel {
self.timer_tick(TimerName::TimeLastDataPacketSent); self.timer_tick(TimerName::TimeLastDataPacketSent);
} }
self.tx_bytes += src.len(); self.tx_bytes += src.len();
return TunnResult::WriteToNetwork(packet) return TunnResult::WriteToNetwork(packet);
} }
// If there is no session, queue the packet for future retry // If there is no session, queue the packet for future retry
@ -302,7 +298,7 @@ impl Tunnel {
) -> TunnResult<'a> { ) -> TunnResult<'a> {
if datagram.is_empty() { if datagram.is_empty() {
// Indicates a repeated call // Indicates a repeated call
return self.send_queued_packet(dst) return self.send_queued_packet(dst);
} }
let mut cookie = [0u8; COOKIE_REPLY_SZ]; let mut cookie = [0u8; COOKIE_REPLY_SZ];
@ -313,7 +309,7 @@ impl Tunnel {
Ok(packet) => packet, Ok(packet) => packet,
Err(TunnResult::WriteToNetwork(cookie)) => { Err(TunnResult::WriteToNetwork(cookie)) => {
dst[..cookie.len()].copy_from_slice(cookie); dst[..cookie.len()].copy_from_slice(cookie);
return TunnResult::WriteToNetwork(&mut dst[..cookie.len()]) return TunnResult::WriteToNetwork(&mut dst[..cookie.len()]);
} }
Err(TunnResult::Err(e)) => return TunnResult::Err(e), Err(TunnResult::Err(e)) => return TunnResult::Err(e),
_ => unreachable!(), _ => unreachable!(),
@ -413,7 +409,7 @@ impl Tunnel {
let cur_idx = self.current; let cur_idx = self.current;
if cur_idx == new_idx { if cur_idx == new_idx {
// There is nothing to do, already using this session, this is the common case // There is nothing to do, already using this session, this is the common case
return return;
} }
if self.sessions[cur_idx % N_SESSIONS].is_none() if self.sessions[cur_idx % N_SESSIONS].is_none()
|| self.timers.session_timers[new_idx % N_SESSIONS] || self.timers.session_timers[new_idx % N_SESSIONS]
@ -459,7 +455,7 @@ impl Tunnel {
force_resend: bool, force_resend: bool,
) -> TunnResult<'a> { ) -> TunnResult<'a> {
if self.handshake.is_in_progress() && !force_resend { if self.handshake.is_in_progress() && !force_resend {
return TunnResult::Done return TunnResult::Done;
} }
if self.handshake.is_expired() { if self.handshake.is_expired() {
@ -518,7 +514,7 @@ impl Tunnel {
}; };
if computed_len > packet.len() { if computed_len > packet.len() {
return TunnResult::Err(WireGuardError::InvalidPacket) return TunnResult::Err(WireGuardError::InvalidPacket);
} }
self.timer_tick(TimerName::TimeLastDataPacketReceived); self.timer_tick(TimerName::TimeLastDataPacketReceived);

View file

@ -13,19 +13,9 @@ use ring::constant_time::verify_slices_are_equal;
use super::{ use super::{
handshake::{ handshake::{
b2s_hash, b2s_hash, b2s_keyed_mac_16, b2s_keyed_mac_16_2, b2s_mac_24, LABEL_COOKIE, LABEL_MAC1,
b2s_keyed_mac_16,
b2s_keyed_mac_16_2,
b2s_mac_24,
LABEL_COOKIE,
LABEL_MAC1,
}, },
HandshakeInit, HandshakeInit, HandshakeResponse, Packet, TunnResult, Tunnel, WireGuardError,
HandshakeResponse,
Packet,
TunnResult,
Tunnel,
WireGuardError,
}; };
const COOKIE_REFRESH: u64 = 128; // Use 128 and not 120 so the compiler can optimize out the division const COOKIE_REFRESH: u64 = 128; // Use 128 and not 120 so the compiler can optimize out the division
@ -137,7 +127,7 @@ impl RateLimiter {
dst: &'a mut [u8], dst: &'a mut [u8],
) -> Result<&'a mut [u8], WireGuardError> { ) -> Result<&'a mut [u8], WireGuardError> {
if dst.len() < super::COOKIE_REPLY_SZ { if dst.len() < super::COOKIE_REPLY_SZ {
return Err(WireGuardError::DestinationBufferTooSmall) return Err(WireGuardError::DestinationBufferTooSmall);
} }
let (message_type, rest) = dst.split_at_mut(4); let (message_type, rest) = dst.split_at_mut(4);
@ -174,14 +164,14 @@ impl RateLimiter {
dst: &'b mut [u8], dst: &'b mut [u8],
) -> Result<Packet<'a>, TunnResult<'b>> { ) -> Result<Packet<'a>, TunnResult<'b>> {
let packet = Tunnel::parse_incoming_packet(src)?; let packet = Tunnel::parse_incoming_packet(src)?;
log::debug!("packet: {:?}", packet); tracing::debug!("packet: {:?}", packet);
// Verify and rate limit handshake messages only // Verify and rate limit handshake messages only
if let Packet::HandshakeInit(HandshakeInit { sender_idx, .. }) if let Packet::HandshakeInit(HandshakeInit { sender_idx, .. })
| Packet::HandshakeResponse(HandshakeResponse { sender_idx, .. }) = packet | Packet::HandshakeResponse(HandshakeResponse { sender_idx, .. }) = packet
{ {
log::debug!("sender_idx: {}", sender_idx); tracing::debug!("sender_idx: {}", sender_idx);
log::debug!("response: {:?}", packet); tracing::debug!("response: {:?}", packet);
let (msg, macs) = src.split_at(src.len() - 32); let (msg, macs) = src.split_at(src.len() - 32);
let (mac1, mac2) = macs.split_at(16); let (mac1, mac2) = macs.split_at(16);
@ -203,7 +193,7 @@ impl RateLimiter {
let cookie_packet = self let cookie_packet = self
.format_cookie_reply(sender_idx, cookie, mac1, dst) .format_cookie_reply(sender_idx, cookie, mac1, dst)
.map_err(TunnResult::Err)?; .map_err(TunnResult::Err)?;
return Err(TunnResult::WriteToNetwork(cookie_packet)) return Err(TunnResult::WriteToNetwork(cookie_packet));
} }
} }
} }

View file

@ -88,11 +88,11 @@ impl ReceivingKeyCounterValidator {
fn will_accept(&self, counter: u64) -> Result<(), WireGuardError> { fn will_accept(&self, counter: u64) -> Result<(), WireGuardError> {
if counter >= self.next { if counter >= self.next {
// As long as the counter is growing no replay took place for sure // As long as the counter is growing no replay took place for sure
return Ok(()) return Ok(());
} }
if counter + N_BITS < self.next { if counter + N_BITS < self.next {
// Drop if too far back // Drop if too far back
return Err(WireGuardError::InvalidCounter) return Err(WireGuardError::InvalidCounter);
} }
if !self.check_bit(counter) { if !self.check_bit(counter) {
Ok(()) Ok(())
@ -107,22 +107,22 @@ impl ReceivingKeyCounterValidator {
fn mark_did_receive(&mut self, counter: u64) -> Result<(), WireGuardError> { fn mark_did_receive(&mut self, counter: u64) -> Result<(), WireGuardError> {
if counter + N_BITS < self.next { if counter + N_BITS < self.next {
// Drop if too far back // Drop if too far back
return Err(WireGuardError::InvalidCounter) return Err(WireGuardError::InvalidCounter);
} }
if counter == self.next { if counter == self.next {
// Usually the packets arrive in order, in that case we simply mark the bit and // Usually the packets arrive in order, in that case we simply mark the bit and
// increment the counter // increment the counter
self.set_bit(counter); self.set_bit(counter);
self.next += 1; self.next += 1;
return Ok(()) return Ok(());
} }
if counter < self.next { if counter < self.next {
// A packet arrived out of order, check if it is valid, and mark // A packet arrived out of order, check if it is valid, and mark
if self.check_bit(counter) { if self.check_bit(counter) {
return Err(WireGuardError::InvalidCounter) return Err(WireGuardError::InvalidCounter);
} }
self.set_bit(counter); self.set_bit(counter);
return Ok(()) return Ok(());
} }
// Packets where dropped, or maybe reordered, skip them and mark unused // Packets where dropped, or maybe reordered, skip them and mark unused
if counter - self.next >= N_BITS { if counter - self.next >= N_BITS {
@ -247,13 +247,13 @@ impl Session {
panic!("The destination buffer is too small"); panic!("The destination buffer is too small");
} }
if packet.receiver_idx != self.receiving_index { if packet.receiver_idx != self.receiving_index {
return Err(WireGuardError::WrongIndex) return Err(WireGuardError::WrongIndex);
} }
// Don't reuse counters, in case this is a replay attack we want to quickly // Don't reuse counters, in case this is a replay attack we want to quickly
// check the counter without running expensive decryption // check the counter without running expensive decryption
self.receiving_counter_quick_check(packet.counter)?; self.receiving_counter_quick_check(packet.counter)?;
log::debug!("TAG C"); tracing::debug!("TAG C");
let ret = { let ret = {
let mut nonce = [0u8; 12]; let mut nonce = [0u8; 12];
nonce[4..12].copy_from_slice(&packet.counter.to_le_bytes()); nonce[4..12].copy_from_slice(&packet.counter.to_le_bytes());

View file

@ -190,7 +190,7 @@ impl Tunnel {
{ {
if self.handshake.is_expired() { if self.handshake.is_expired() {
return TunnResult::Err(WireGuardError::ConnectionExpired) return TunnResult::Err(WireGuardError::ConnectionExpired);
} }
// Clear cookie after COOKIE_EXPIRATION_TIME // Clear cookie after COOKIE_EXPIRATION_TIME
@ -206,7 +206,7 @@ impl Tunnel {
tracing::error!("CONNECTION_EXPIRED(REJECT_AFTER_TIME * 3)"); tracing::error!("CONNECTION_EXPIRED(REJECT_AFTER_TIME * 3)");
self.handshake.set_expired(); self.handshake.set_expired();
self.clear_all(); self.clear_all();
return TunnResult::Err(WireGuardError::ConnectionExpired) return TunnResult::Err(WireGuardError::ConnectionExpired);
} }
if let Some(time_init_sent) = self.handshake.timer() { if let Some(time_init_sent) = self.handshake.timer() {
@ -219,7 +219,7 @@ impl Tunnel {
tracing::error!("CONNECTION_EXPIRED(REKEY_ATTEMPT_TIME)"); tracing::error!("CONNECTION_EXPIRED(REKEY_ATTEMPT_TIME)");
self.handshake.set_expired(); self.handshake.set_expired();
self.clear_all(); self.clear_all();
return TunnResult::Err(WireGuardError::ConnectionExpired) return TunnResult::Err(WireGuardError::ConnectionExpired);
} }
if time_init_sent.elapsed() >= REKEY_TIMEOUT { if time_init_sent.elapsed() >= REKEY_TIMEOUT {
@ -299,11 +299,11 @@ impl Tunnel {
} }
if handshake_initiation_required { if handshake_initiation_required {
return self.format_handshake_initiation(dst, true) return self.format_handshake_initiation(dst, true);
} }
if keepalive_required { if keepalive_required {
return self.encapsulate(&[], dst) return self.encapsulate(&[], dst);
} }
TunnResult::Done TunnResult::Done

View file

@ -9,11 +9,11 @@ use fehler::throws;
use ip_network::IpNetwork; use ip_network::IpNetwork;
use log::log; use log::log;
use rand::random; use rand::random;
use tokio::{net::UdpSocket, task::JoinHandle};
use tokio::sync::{Mutex, RwLock}; use tokio::sync::{Mutex, RwLock};
use tokio::time::timeout; use tokio::time::timeout;
use uuid::uuid; use tokio::{net::UdpSocket, task::JoinHandle};
use tun::tokio::TunInterface; use tun::tokio::TunInterface;
use uuid::uuid;
use super::{ use super::{
iface::PacketInterface, iface::PacketInterface,
@ -33,15 +33,24 @@ pub struct PeerPcb {
impl PeerPcb { impl PeerPcb {
#[throws] #[throws]
pub fn new(peer: Peer) -> Self { pub fn new(peer: Peer) -> Self {
let tunnel = RwLock::new(Tunnel::new(peer.private_key, peer.public_key, peer.preshared_key, None, 1, None) let tunnel = RwLock::new(
.map_err(|s| anyhow::anyhow!("{}", s))?); Tunnel::new(
peer.private_key,
peer.public_key,
peer.preshared_key,
None,
1,
None,
)
.map_err(|s| anyhow::anyhow!("{}", s))?,
);
Self { Self {
endpoint: peer.endpoint, endpoint: peer.endpoint,
allowed_ips: peer.allowed_ips, allowed_ips: peer.allowed_ips,
handle: None, handle: None,
socket: None, socket: None,
tunnel tunnel,
} }
} }
@ -56,7 +65,7 @@ impl PeerPcb {
pub async fn run(&self, tun_interface: Arc<RwLock<TunInterface>>) -> Result<(), Error> { pub async fn run(&self, tun_interface: Arc<RwLock<TunInterface>>) -> Result<(), Error> {
let mut buf = [0u8; 3000]; let mut buf = [0u8; 3000];
log::debug!("starting read loop for pcb..."); tracing::debug!("starting read loop for pcb...");
loop { loop {
tracing::debug!("waiting for packet"); tracing::debug!("waiting for packet");
let len = self.recv(&mut buf, tun_interface.clone()).await?; let len = self.recv(&mut buf, tun_interface.clone()).await?;
@ -64,29 +73,38 @@ impl PeerPcb {
} }
} }
pub async fn recv(&self, buf: &mut [u8], tun_interface: Arc<RwLock<TunInterface>>) -> Result<usize, Error> { pub async fn recv(
log::debug!("starting read loop for pcb... for {:?}", &self); &self,
buf: &mut [u8],
tun_interface: Arc<RwLock<TunInterface>>,
) -> Result<usize, Error> {
tracing::debug!("starting read loop for pcb... for {:?}", &self);
let rid: i32 = random(); let rid: i32 = random();
log::debug!("start read loop {}", rid); tracing::debug!("start read loop {}", rid);
loop{ loop {
log::debug!("{}: waiting for packet", rid); tracing::debug!("{}: waiting for packet", rid);
let Some(socket) = &self.socket else { let Some(socket) = &self.socket else {
continue continue
}; };
let mut res_buf = [0;1500]; let mut res_buf = [0; 1500];
// log::debug!("{} : waiting for readability on {:?}", rid, socket); // tracing::debug!("{} : waiting for readability on {:?}", rid, socket);
let len = match socket.recv(&mut res_buf).await { let len = match socket.recv(&mut res_buf).await {
Ok(l) => {l} Ok(l) => l,
Err(e) => { Err(e) => {
log::error!("{}: error reading from socket: {:?}", rid, e); log::error!("{}: error reading from socket: {:?}", rid, e);
continue continue;
} }
}; };
let mut res_dat = &res_buf[..len]; let mut res_dat = &res_buf[..len];
tracing::debug!("{}: Decapsulating {} bytes", rid, len); tracing::debug!("{}: Decapsulating {} bytes", rid, len);
tracing::debug!("{:?}", &res_dat); tracing::debug!("{:?}", &res_dat);
loop { loop {
match self.tunnel.write().await.decapsulate(None, res_dat, &mut buf[..]) { match self
.tunnel
.write()
.await
.decapsulate(None, res_dat, &mut buf[..])
{
TunnResult::Done => { TunnResult::Done => {
break; break;
} }
@ -113,7 +131,7 @@ impl PeerPcb {
} }
} }
} }
return Ok(len) return Ok(len);
} }
} }
@ -122,7 +140,6 @@ impl PeerPcb {
Ok(self.socket.as_ref().expect("socket was just opened")) Ok(self.socket.as_ref().expect("socket was just opened"))
} }
pub async fn send(&self, src: &[u8]) -> Result<(), Error> { pub async fn send(&self, src: &[u8]) -> Result<(), Error> {
let mut dst_buf = [0u8; 3000]; let mut dst_buf = [0u8; 3000];
match self.tunnel.write().await.encapsulate(src, &mut dst_buf[..]) { match self.tunnel.write().await.encapsulate(src, &mut dst_buf[..]) {

View file

@ -10,7 +10,7 @@ pub struct Peer {
pub private_key: StaticSecret, pub private_key: StaticSecret,
pub public_key: PublicKey, pub public_key: PublicKey,
pub allowed_ips: Vec<IpNetwork>, pub allowed_ips: Vec<IpNetwork>,
pub preshared_key: Option<[u8; 32]> pub preshared_key: Option<[u8; 32]>,
} }
impl fmt::Debug for Peer { impl fmt::Debug for Peer {

View file

@ -26,7 +26,7 @@ async fn generate(out_dir: &std::path::Path) -> anyhow::Result<()> {
println!("cargo:rerun-if-changed={}", binary_path.to_str().unwrap()); println!("cargo:rerun-if-changed={}", binary_path.to_str().unwrap());
if let (Ok(..), Ok(..)) = (File::open(&bindings_path), File::open(&binary_path)) { if let (Ok(..), Ok(..)) = (File::open(&bindings_path), File::open(&binary_path)) {
return Ok(()) return Ok(());
}; };
let archive = download(out_dir) let archive = download(out_dir)

View file

@ -15,4 +15,4 @@ mod options;
pub mod tokio; pub mod tokio;
pub use options::TunOptions; pub use options::TunOptions;
pub use os_imp::{TunInterface, TunQueue}; pub use os_imp::{retrieve, TunInterface, TunQueue};

View file

@ -5,31 +5,42 @@ use fehler::throws;
use super::TunInterface; use super::TunInterface;
#[derive(Debug, Clone, Default)] #[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema))] #[cfg_attr(
feature = "serde",
derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema)
)]
pub struct TunOptions { pub struct TunOptions {
/// (Windows + Linux) Name the tun interface. /// (Windows + Linux) Name the tun interface.
pub(crate) name: Option<String>, pub name: Option<String>,
/// (Linux) Don't include packet information. /// (Linux) Don't include packet information.
pub(crate) no_pi: Option<()>, pub no_pi: bool,
/// (Linux) Avoid opening an existing persistant device. /// (Linux) Avoid opening an existing persistant device.
pub(crate) tun_excl: Option<()>, pub tun_excl: bool,
/// (MacOS) Whether to seek the first available utun device. /// (MacOS) Whether to seek the first available utun device.
pub(crate) seek_utun: Option<()>, pub seek_utun: Option<i32>,
/// (Linux) The IP address of the tun interface. /// (Linux) The IP address of the tun interface.
pub(crate) address: Option<String>, pub address: Option<String>,
} }
impl TunOptions { impl TunOptions {
pub fn new() -> Self { Self::default() } pub fn new() -> Self {
Self::default()
}
pub fn name(mut self, name: &str) -> Self { pub fn name(mut self, name: &str) -> Self {
self.name = Some(name.to_owned()); self.name = Some(name.to_owned());
self self
} }
pub fn no_pi(mut self, enable: bool) { self.no_pi = enable.then_some(()); } pub fn no_pi(mut self, enable: bool) -> Self {
self.no_pi = enable;
self
}
pub fn tun_excl(mut self, enable: bool) { self.tun_excl = enable.then_some(()); } pub fn tun_excl(mut self, enable: bool) -> Self {
self.tun_excl = enable;
self
}
pub fn address(mut self, address: impl ToString) -> Self { pub fn address(mut self, address: impl ToString) -> Self {
self.address = Some(address.to_string()); self.address = Some(address.to_string());
@ -37,5 +48,7 @@ impl TunOptions {
} }
#[throws] #[throws]
pub fn open(self) -> TunInterface { TunInterface::new_with_options(self)? } pub fn open(self) -> TunInterface {
TunInterface::new_with_options(self)?
}
} }

View file

@ -12,9 +12,7 @@ impl TunInterface {
#[instrument] #[instrument]
pub fn new(mut tun: crate::TunInterface) -> io::Result<Self> { pub fn new(mut tun: crate::TunInterface) -> io::Result<Self> {
tun.set_nonblocking(true)?; tun.set_nonblocking(true)?;
Ok(Self { Ok(Self { inner: AsyncFd::new(tun)? })
inner: AsyncFd::new(tun)?,
})
} }
#[instrument] #[instrument]
@ -31,22 +29,22 @@ impl TunInterface {
// #[instrument] // #[instrument]
pub async fn recv(&mut self, buf: &mut [u8]) -> io::Result<usize> { pub async fn recv(&mut self, buf: &mut [u8]) -> io::Result<usize> {
loop { loop {
// log::debug!("TunInterface receiving..."); // tracing::debug!("TunInterface receiving...");
let mut guard = self.inner.readable_mut().await?; let mut guard = self.inner.readable_mut().await?;
// log::debug!("Got! readable_mut"); // tracing::debug!("Got! readable_mut");
match guard.try_io(|inner| { match guard.try_io(|inner| {
let raw_ref = (*inner).get_mut(); let raw_ref = (*inner).get_mut();
let recved = raw_ref.recv(buf); let recved = raw_ref.recv(buf);
recved recved
}) { }) {
Ok(result) => { Ok(result) => {
log::debug!("HORRAY"); tracing::debug!("HORRAY");
return result return result;
}, }
Err(_would_block) => { Err(_would_block) => {
log::debug!("WouldBlock"); tracing::debug!("WouldBlock");
continue continue;
}, }
} }
} }
} }

View file

@ -9,11 +9,12 @@ use byteorder::{ByteOrder, NetworkEndian};
use fehler::throws; use fehler::throws;
use libc::{c_char, iovec, writev, AF_INET, AF_INET6}; use libc::{c_char, iovec, writev, AF_INET, AF_INET6};
use socket2::{Domain, SockAddr, Socket, Type}; use socket2::{Domain, SockAddr, Socket, Type};
use tracing::{self, instrument}; use tracing::{self, debug, instrument};
mod kern_control; mod kern_control;
mod sys; pub mod sys;
use crate::retrieve;
use kern_control::SysControlSocket; use kern_control::SysControlSocket;
pub use super::queue::TunQueue; pub use super::queue::TunQueue;
@ -34,8 +35,13 @@ impl TunInterface {
#[throws] #[throws]
#[instrument] #[instrument]
pub fn new_with_options(options: TunOptions) -> TunInterface { pub fn new_with_options(options: TunOptions) -> TunInterface {
let ti = TunInterface::connect(0)?; debug!("Opening tun interface with options: {:?}", &options);
if let Some(addr) = options.address{ let ti = if let Some(n) = options.seek_utun {
retrieve().ok_or(Error::new(std::io::ErrorKind::NotFound, "No utun found"))?
} else {
TunInterface::connect(0)?
};
if let Some(addr) = options.address {
if let Ok(addr) = addr.parse() { if let Ok(addr) = addr.parse() {
ti.set_ipv4_addr(addr)?; ti.set_ipv4_addr(addr)?;
} }

View file

@ -2,20 +2,11 @@ use std::mem;
use libc::{c_char, c_int, c_short, c_uint, c_ulong, sockaddr}; use libc::{c_char, c_int, c_short, c_uint, c_ulong, sockaddr};
pub use libc::{ pub use libc::{
c_void, c_void, sockaddr_ctl, sockaddr_in, socklen_t, AF_SYSTEM, AF_SYS_CONTROL, IFNAMSIZ,
sockaddr_ctl,
sockaddr_in,
socklen_t,
AF_SYSTEM,
AF_SYS_CONTROL,
IFNAMSIZ,
SYSPROTO_CONTROL, SYSPROTO_CONTROL,
}; };
use nix::{ use nix::{
ioctl_read_bad, ioctl_read_bad, ioctl_readwrite, ioctl_write_ptr_bad, request_code_readwrite,
ioctl_readwrite,
ioctl_write_ptr_bad,
request_code_readwrite,
request_code_write, request_code_write,
}; };

View file

@ -26,7 +26,9 @@ pub struct TunInterface {
impl TunInterface { impl TunInterface {
#[throws] #[throws]
#[instrument] #[instrument]
pub fn new() -> TunInterface { Self::new_with_options(TunOptions::new())? } pub fn new() -> TunInterface {
Self::new_with_options(TunOptions::new())?
}
#[throws] #[throws]
#[instrument] #[instrument]
@ -212,5 +214,7 @@ impl TunInterface {
#[throws] #[throws]
#[instrument] #[instrument]
pub fn send(&self, buf: &[u8]) -> usize { self.socket.send(buf)? } pub fn send(&self, buf: &[u8]) -> usize {
self.socket.send(buf)?
}
} }

View file

@ -1,11 +1,13 @@
use std::mem::size_of;
use std::{ use std::{
io::{Error, Read}, io::{Error, Read},
mem,
os::fd::{AsRawFd, FromRawFd, IntoRawFd, RawFd}, os::fd::{AsRawFd, FromRawFd, IntoRawFd, RawFd},
}; };
use tracing::instrument; use tracing::{debug, error, instrument};
use super::TunOptions; use super::{syscall, TunOptions};
mod queue; mod queue;
@ -17,9 +19,13 @@ mod imp;
#[path = "linux/mod.rs"] #[path = "linux/mod.rs"]
mod imp; mod imp;
use crate::os_imp::imp::sys;
use crate::os_imp::imp::sys::resolve_ctl_info;
use fehler::throws; use fehler::throws;
pub use imp::TunInterface; pub use imp::TunInterface;
use libc::{getpeername, sockaddr_ctl, sockaddr_storage, socklen_t, AF_SYSTEM, AF_SYS_CONTROL};
pub use queue::TunQueue; pub use queue::TunQueue;
use socket2::SockAddr;
impl AsRawFd for TunInterface { impl AsRawFd for TunInterface {
fn as_raw_fd(&self) -> RawFd { fn as_raw_fd(&self) -> RawFd {
@ -47,8 +53,8 @@ impl TunInterface {
// there might be a more efficient way to implement this // there might be a more efficient way to implement this
let tmp_buf = &mut [0u8; 1500]; let tmp_buf = &mut [0u8; 1500];
let len = self.socket.read(tmp_buf)?; let len = self.socket.read(tmp_buf)?;
buf[..len-4].copy_from_slice(&tmp_buf[4..len]); buf[..len - 4].copy_from_slice(&tmp_buf[4..len]);
len-4 len - 4
} }
#[throws] #[throws]
@ -76,3 +82,35 @@ pub fn string_to_ifname(name: &str) -> [libc::c_char; libc::IFNAMSIZ] {
buf[..len].copy_from_slice(unsafe { &*(name.as_bytes() as *const _ as *const [libc::c_char]) }); buf[..len].copy_from_slice(unsafe { &*(name.as_bytes() as *const _ as *const [libc::c_char]) });
buf buf
} }
#[cfg(any(target_os = "linux", target_vendor = "apple"))]
pub fn retrieve() -> Option<TunInterface> {
(3..100)
.filter_map(|i| {
let result = unsafe {
let mut addr = sockaddr_ctl {
sc_len: size_of::<sockaddr_ctl>() as u8,
sc_family: 0,
ss_sysaddr: 0,
sc_id: 0,
sc_unit: 0,
sc_reserved: Default::default(),
};
let mut len = mem::size_of::<sockaddr_ctl>() as libc::socklen_t;
let res = syscall!(getpeername(i, &mut addr as *mut _ as *mut _, len as *mut _));
tracing::debug!("getpeername{}: {:?}", i, res);
if res.is_err() {
return None;
}
if addr.sc_family == sys::AF_SYSTEM as u8
&& addr.ss_sysaddr == sys::AF_SYS_CONTROL as u16
{
Some(TunInterface::from_raw_fd(i))
} else {
None
}
};
result
})
.next()
}

View file

@ -25,7 +25,9 @@ impl Debug for TunInterface {
impl TunInterface { impl TunInterface {
#[throws] #[throws]
pub fn new() -> TunInterface { Self::new_with_options(TunOptions::new())? } pub fn new() -> TunInterface {
Self::new_with_options(TunOptions::new())?
}
#[throws] #[throws]
pub(crate) fn new_with_options(options: TunOptions) -> TunInterface { pub(crate) fn new_with_options(options: TunOptions) -> TunInterface {
@ -37,17 +39,18 @@ impl TunInterface {
if handle.is_null() { if handle.is_null() {
unsafe { GetLastError() }.ok()? unsafe { GetLastError() }.ok()?
} }
TunInterface { TunInterface { handle, name: name_owned }
handle,
name: name_owned,
}
} }
pub fn name(&self) -> String { self.name.clone() } pub fn name(&self) -> String {
self.name.clone()
}
} }
impl Drop for TunInterface { impl Drop for TunInterface {
fn drop(&mut self) { unsafe { sys::WINTUN.WintunCloseAdapter(self.handle) } } fn drop(&mut self) {
unsafe { sys::WINTUN.WintunCloseAdapter(self.handle) }
}
} }
pub(crate) mod sys { pub(crate) mod sys {

View file

@ -5,7 +5,9 @@ use tun::TunInterface;
#[test] #[test]
#[throws] #[throws]
fn test_create() { TunInterface::new()?; } fn test_create() {
TunInterface::new()?;
}
#[test] #[test]
#[throws] #[throws]