Compare commits

..

10 commits

Author SHA1 Message Date
Conrad Kramer
104f8215ba Fixed a number of warnings 2023-12-17 19:42:31 -08:00
Jett Chen
76278809ea Add break to loop 2023-12-18 11:28:10 +08:00
Jett Chen
889ed37f80 Update logging 2023-12-17 17:59:37 +08:00
Jett Chen
286ecfa590 Fix Swiftlints 2023-12-17 17:52:13 +08:00
Jett Chen
669eed0dac remove timeout 2023-12-17 03:37:13 +08:00
Jett Chen
9f84fc6efa update for windows 2023-12-17 03:22:18 +08:00
Jett Chen
b60c6ad687 Update tokio dep 2023-12-17 02:48:56 +08:00
Jett Chen
1e7750606f Update dependencies 2023-12-17 02:40:01 +08:00
Jett Chen
2b3ef999b9 Update Windows workflow to support ring 2023-12-17 02:28:50 +08:00
Jett Chen
a756630316 update snapshots 2023-12-17 02:19:41 +08:00
35 changed files with 237 additions and 257 deletions

View file

@ -54,6 +54,10 @@ jobs:
run: | run: |
sudo apt-get update sudo apt-get update
sudo apt-get install -y ${{ join(matrix.packages, ' ') }} sudo apt-get install -y ${{ join(matrix.packages, ' ') }}
- name: Install Windows Deps
if: matrix.os == 'windows-2022'
shell: bash
run: echo "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\VC\Tools\Llvm\x64\bin" >> $GITHUB_PATH
- name: Install Rust - name: Install Rust
uses: dtolnay/rust-toolchain@stable uses: dtolnay/rust-toolchain@stable
with: with:

View file

@ -7,10 +7,10 @@ enum BurrowError: Error {
case resultIsNone case resultIsNone
} }
protocol Request: Codable where T: Codable{ protocol Request: Codable where CommandT: Codable {
associatedtype T associatedtype CommandT
var id: UInt { get set } var id: UInt { get set }
var command: T { get set } var command: CommandT { get set }
} }
struct BurrowSingleCommand: Request { struct BurrowSingleCommand: Request {
@ -18,27 +18,30 @@ struct BurrowSingleCommand: Request {
var command: String var command: String
} }
struct BurrowRequest<T>: Request where T: Codable{ struct BurrowRequest<T>: Request where T: Codable {
var id: UInt var id: UInt
var command: T var command: T
} }
struct BurrowStartRequest: Codable { struct BurrowStartRequest: Codable {
struct TunOptions: Codable{ struct TunOptions: Codable {
let name: String? let name: String?
let no_pi: Bool let no_pi: Bool
let tun_excl: Bool let tun_excl: Bool
let tun_retrieve: Bool let tun_retrieve: Bool
let address: String? let address: String?
} }
struct StartOptions: Codable{ struct StartOptions: Codable {
let tun: TunOptions let tun: TunOptions
} }
let Start: StartOptions let Start: StartOptions
} }
func start_req_fd(id: UInt) -> BurrowRequest<BurrowStartRequest> { func start_req_fd(id: UInt) -> BurrowRequest<BurrowStartRequest> {
return BurrowRequest(id: id, command: BurrowStartRequest(Start: BurrowStartRequest.StartOptions(tun: BurrowStartRequest.TunOptions(name: nil, no_pi: false, tun_excl: false, tun_retrieve: true, address: nil)))) let command = BurrowStartRequest(Start: BurrowStartRequest.StartOptions(
tun: BurrowStartRequest.TunOptions(name: nil, no_pi: false, tun_excl: false, tun_retrieve: true, address: nil)
))
return BurrowRequest(id: id, command: command)
} }
struct Response<T>: Decodable where T: Decodable { struct Response<T>: Decodable where T: Decodable {

View file

@ -6,7 +6,7 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
let logger = Logger(subsystem: "com.hackclub.burrow", category: "frontend") let logger = Logger(subsystem: "com.hackclub.burrow", category: "frontend")
var client: BurrowIpc? var client: BurrowIpc?
var osInitialized = false var osInitialized = false
override func startTunnel(options: [String : NSObject]? = nil) async throws { override func startTunnel(options: [String: NSObject]? = nil) async throws {
logger.log("Starting tunnel") logger.log("Starting tunnel")
if !osInitialized { if !osInitialized {
libburrow.initialize_oslog() libburrow.initialize_oslog()
@ -34,13 +34,13 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
// let tunFd = self.packetFlow.value(forKeyPath: "socket.fileDescriptor") as! Int; // let tunFd = self.packetFlow.value(forKeyPath: "socket.fileDescriptor") as! Int;
// self.logger.info("Found File Descriptor: \(tunFd)") // self.logger.info("Found File Descriptor: \(tunFd)")
let start_command = start_req_fd(id: 1) let startCommand = start_req_fd(id: 1)
guard let data = try await client?.request(start_command, type: Response<BurrowResult<String>>.self) guard let data = try await client?.request(startCommand, type: Response<BurrowResult<String>>.self)
else { else {
throw BurrowError.cantParseResult throw BurrowError.cantParseResult
} }
let encoded_startres = try JSONEncoder().encode(data.result) let encodedStartRes = try JSONEncoder().encode(data.result)
self.logger.log("Received start server response: \(String(decoding: encoded_startres, as: UTF8.self))") self.logger.log("Received start server response: \(String(decoding: encodedStartRes, as: UTF8.self))")
} catch { } catch {
self.logger.error("An error occurred: \(error)") self.logger.error("An error occurred: \(error)")
throw error throw error
@ -58,15 +58,12 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
return nst return nst
} }
override func stopTunnel(with reason: NEProviderStopReason) async { override func stopTunnel(with reason: NEProviderStopReason) async {
} }
override func handleAppMessage(_ messageData: Data) async -> Data? { override func handleAppMessage(_ messageData: Data) async -> Data? {
messageData messageData
} }
override func sleep() async { override func sleep() async {
} }
override func wake() { override func wake() {
} }
} }

64
Cargo.lock generated
View file

@ -110,24 +110,15 @@ checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711"
[[package]] [[package]]
name = "async-channel" name = "async-channel"
version = "1.9.0" version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "81953c529336010edd6d8e358f886d9581267795c61b19475b71314bffa46d35" checksum = "1ca33f4bc4ed1babef42cad36cc1f51fa88be00420404e5b1e80ab1b18f7678c"
dependencies = [ dependencies = [
"concurrent-queue", "concurrent-queue",
"event-listener", "event-listener",
"event-listener-strategy",
"futures-core", "futures-core",
] "pin-project-lite",
[[package]]
name = "async-trait"
version = "0.1.74"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a66537f1bb974b254c98ed142ff995236e81b9d0fe4db0575f46612cb15eb0f9"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.39",
] ]
[[package]] [[package]]
@ -251,7 +242,6 @@ dependencies = [
"aead", "aead",
"anyhow", "anyhow",
"async-channel", "async-channel",
"async-trait",
"base64", "base64",
"blake2", "blake2",
"caps", "caps",
@ -639,9 +629,24 @@ dependencies = [
[[package]] [[package]]
name = "event-listener" name = "event-listener"
version = "2.5.3" version = "4.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0206175f82b8d6bf6652ff7d71a1e27fd2e4efde587fd368662814d6ec1d9ce0" checksum = "770d968249b5d99410d61f5bf89057f3199a077a04d087092f58e7d10692baae"
dependencies = [
"concurrent-queue",
"parking",
"pin-project-lite",
]
[[package]]
name = "event-listener-strategy"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "958e4d70b6d5e81971bebec42271ec641e7ff4e170a6fa605f2b8a8b65cb97d3"
dependencies = [
"event-listener",
"pin-project-lite",
]
[[package]] [[package]]
name = "fastrand" name = "fastrand"
@ -1390,6 +1395,12 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39"
[[package]]
name = "parking"
version = "2.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bb813b8af86854136c6922af0598d719255ecb2179515e6e7730d468f05c9cae"
[[package]] [[package]]
name = "parking_lot" name = "parking_lot"
version = "0.12.1" version = "0.12.1"
@ -1450,9 +1461,9 @@ checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94"
[[package]] [[package]]
name = "pin-project-lite" name = "pin-project-lite"
version = "0.2.9" version = "0.2.13"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e0a7ae3ac2f1173085d398531c705756c94a4c56843785df85a60c1a0afac116" checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58"
[[package]] [[package]]
name = "pin-utils" name = "pin-utils"
@ -1627,17 +1638,16 @@ dependencies = [
[[package]] [[package]]
name = "ring" name = "ring"
version = "0.16.20" version = "0.17.7"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3053cf52e236a3ed746dfc745aa9cacf1b791d846bdaf412f60a8d7d6e17c8fc" checksum = "688c63d65483050968b2a8937f7995f443e27041a0f7700aa59b0822aedebb74"
dependencies = [ dependencies = [
"cc", "cc",
"getrandom",
"libc", "libc",
"once_cell",
"spin", "spin",
"untrusted", "untrusted",
"web-sys", "windows-sys 0.48.0",
"winapi",
] ]
[[package]] [[package]]
@ -1897,9 +1907,9 @@ dependencies = [
[[package]] [[package]]
name = "spin" name = "spin"
version = "0.5.2" version = "0.9.8"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"
[[package]] [[package]]
name = "ssri" name = "ssri"
@ -2266,9 +2276,9 @@ dependencies = [
[[package]] [[package]]
name = "untrusted" name = "untrusted"
version = "0.7.1" version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1"
[[package]] [[package]]
name = "url" name = "url"

View file

@ -28,7 +28,7 @@ rand = "0.8.5"
rand_core = "0.6.4" rand_core = "0.6.4"
aead = "0.5.2" aead = "0.5.2"
x25519-dalek = { version = "2.0.0", features = ["reusable_secrets", "static_secrets"] } x25519-dalek = { version = "2.0.0", features = ["reusable_secrets", "static_secrets"] }
ring = "0.16.20" ring = "0.17.7"
parking_lot = "0.12.1" parking_lot = "0.12.1"
hmac = "0.12" hmac = "0.12"
ipnet = { version = "2.8.0", features = ["serde"] } ipnet = { version = "2.8.0", features = ["serde"] }
@ -36,8 +36,7 @@ base64 = "0.21.4"
fehler = "1.0.0" fehler = "1.0.0"
ip_network_table = "0.2.0" ip_network_table = "0.2.0"
ip_network = "0.4.0" ip_network = "0.4.0"
async-trait = "0.1.74" async-channel = "2.1.1"
async-channel = "1.9"
schemars = "0.8" schemars = "0.8"
futures = "0.3.28" futures = "0.3.28"
uuid = { version = "1.6.1", features = ["v4"] } uuid = { version = "1.6.1", features = ["v4"] }

View file

@ -41,10 +41,6 @@ impl DaemonInstance {
} }
} }
pub fn set_tun_interface(&mut self, tun_interface: Arc<RwLock<TunInterface>>) {
self.tun_interface = Some(tun_interface);
}
async fn proc_command(&mut self, command: DaemonCommand) -> Result<DaemonResponseData> { async fn proc_command(&mut self, command: DaemonCommand) -> Result<DaemonResponseData> {
info!("Daemon got command: {:?}", command); info!("Daemon got command: {:?}", command);
match command { match command {
@ -111,7 +107,6 @@ impl DaemonInstance {
} }
pub async fn run(&mut self) -> Result<()> { pub async fn run(&mut self) -> Result<()> {
tracing::info!("BEGIN");
while let Ok(command) = self.rx.recv().await { while let Ok(command) = self.rx.recv().await {
let response = self.proc_command(command).await; let response = self.proc_command(command).await;
info!("Daemon response: {:?}", response); info!("Daemon response: {:?}", response);

View file

@ -1,52 +1,24 @@
use std::net::ToSocketAddrs; use std::sync::Arc;
use std::{
net::{IpAddr, Ipv4Addr, SocketAddr},
sync::Arc,
};
mod command; mod command;
mod instance; mod instance;
mod net; mod net;
mod response; mod response;
use anyhow::{anyhow, Error, Result}; use anyhow::Result;
use base64::{engine::general_purpose, Engine as _};
pub use command::{DaemonCommand, DaemonStartOptions}; pub use command::{DaemonCommand, DaemonStartOptions};
use fehler::throws;
use instance::DaemonInstance; use instance::DaemonInstance;
use ip_network::{IpNetwork, Ipv4Network};
#[cfg(target_vendor = "apple")] #[cfg(target_vendor = "apple")]
pub use net::start_srv; pub use net::start_srv;
pub use net::DaemonClient; pub use net::DaemonClient;
pub use response::{DaemonResponse, DaemonResponseData, ServerInfo}; pub use response::{DaemonResponse, DaemonResponseData, ServerInfo};
use tokio::sync::RwLock; use tokio::sync::RwLock;
use crate::wireguard::Config;
use crate::{ use crate::{
daemon::net::listen, daemon::net::listen,
wireguard::{Interface, Peer, PublicKey, StaticSecret}, wireguard::{Config, Interface},
}; };
#[throws]
fn parse_key(string: &str) -> [u8; 32] {
let value = general_purpose::STANDARD.decode(string)?;
let mut key = [0u8; 32];
key.copy_from_slice(&value[..]);
key
}
#[throws]
fn parse_secret_key(string: &str) -> StaticSecret {
let key = parse_key(string)?;
StaticSecret::from(key)
}
#[throws]
fn parse_public_key(string: &str) -> PublicKey {
let key = parse_key(string)?;
PublicKey::from(key)
}
pub async fn daemon_main() -> Result<()> { pub async fn daemon_main() -> Result<()> {
let (commands_tx, commands_rx) = async_channel::unbounded(); let (commands_tx, commands_rx) = async_channel::unbounded();
let (response_tx, response_rx) = async_channel::unbounded(); let (response_tx, response_rx) = async_channel::unbounded();
@ -73,6 +45,7 @@ pub async fn daemon_main() -> Result<()> {
} }
}); });
tokio::try_join!(inst_job, listen_job).map(|_| ()); tokio::try_join!(inst_job, listen_job)
Ok(()) .map(|_| ())
.map_err(|e| e.into())
} }

View file

@ -7,7 +7,7 @@ use crate::daemon::{daemon_main, DaemonClient};
#[no_mangle] #[no_mangle]
pub extern "C" fn start_srv() { pub extern "C" fn start_srv() {
info!("Rust: Starting server"); info!("Starting server");
let _handle = thread::spawn(move || { let _handle = thread::spawn(move || {
let rt = Runtime::new().unwrap(); let rt = Runtime::new().unwrap();
rt.block_on(async { rt.block_on(async {
@ -20,10 +20,11 @@ pub extern "C" fn start_srv() {
rt.block_on(async { rt.block_on(async {
loop { loop {
match DaemonClient::new().await { match DaemonClient::new().await {
Ok(_) => break, Ok(..) => {
Err(_e) => { info!("Server successfully started");
// error!("Error when connecting to daemon: {}", e) break
} }
Err(e) => error!("Could not connect to server: {}", e),
} }
} }
}); });

View file

@ -1,7 +1,9 @@
use std::os::fd::IntoRawFd;
use anyhow::Result;
use super::*; use super::*;
use crate::daemon::DaemonResponse; use crate::daemon::DaemonResponse;
use anyhow::Result;
use std::os::fd::IntoRawFd;
pub async fn listen( pub async fn listen(
cmd_tx: async_channel::Sender<DaemonCommand>, cmd_tx: async_channel::Sender<DaemonCommand>,

View file

@ -6,16 +6,16 @@ use std::{
}, },
path::{Path, PathBuf}, path::{Path, PathBuf},
}; };
use tracing::info;
use crate::daemon::{DaemonCommand, DaemonResponse, DaemonResponseData};
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use tokio::{ use tokio::{
io::{AsyncBufReadExt, AsyncWriteExt, BufReader}, io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
net::{UnixListener, UnixStream}, net::{UnixListener, UnixStream},
}; };
use tracing::debug; use tracing::{debug, info};
use super::*; use super::*;
use crate::daemon::{DaemonCommand, DaemonResponse, DaemonResponseData};
#[cfg(not(target_vendor = "apple"))] #[cfg(not(target_vendor = "apple"))]
const UNIX_SOCKET_PATH: &str = "/run/burrow.sock"; const UNIX_SOCKET_PATH: &str = "/run/burrow.sock";
@ -36,7 +36,7 @@ fn fetch_socket_path() -> Option<PathBuf> {
for path in tries { for path in tries {
let path = PathBuf::from(path); let path = PathBuf::from(path);
if path.exists() { if path.exists() {
return Some(path); return Some(path)
} }
} }
None None

View file

@ -1,4 +1,7 @@
use anyhow::Result;
use super::*; use super::*;
use crate::daemon::DaemonResponse;
pub async fn listen( pub async fn listen(
_cmd_tx: async_channel::Sender<DaemonCommand>, _cmd_tx: async_channel::Sender<DaemonCommand>,

View file

@ -18,9 +18,9 @@ impl DaemonResponse {
} }
} }
impl Into<DaemonResponse> for DaemonResponseData { impl From<DaemonResponseData> for DaemonResponse {
fn into(self) -> DaemonResponse { fn from(val: DaemonResponseData) -> Self {
DaemonResponse::new(Ok::<DaemonResponseData, String>(self)) DaemonResponse::new(Ok::<DaemonResponseData, String>(val))
} }
} }

View file

@ -1,5 +1,5 @@
--- ---
source: burrow/src/daemon/command.rs source: burrow/src/daemon/command.rs
expression: "serde_json::to_string(&DaemonCommand::Start(DaemonStartOptions {\n tun: TunOptions { seek_utun: true, ..TunOptions::default() },\n })).unwrap()" expression: "serde_json::to_string(&DaemonCommand::Start(DaemonStartOptions {\n tun: TunOptions { ..TunOptions::default() },\n })).unwrap()"
--- ---
{"Start":{"tun":{"name":null,"no_pi":false,"tun_excl":false,"seek_utun":true,"address":null}}} {"Start":{"tun":{"name":null,"no_pi":false,"tun_excl":false,"tun_retrieve":false,"address":null}}}

View file

@ -2,4 +2,4 @@
source: burrow/src/daemon/command.rs source: burrow/src/daemon/command.rs
expression: "serde_json::to_string(&DaemonCommand::Start(DaemonStartOptions::default())).unwrap()" expression: "serde_json::to_string(&DaemonCommand::Start(DaemonStartOptions::default())).unwrap()"
--- ---
{"Start":{"tun":{"name":null,"no_pi":false,"tun_excl":false,"seek_utun":false,"address":null}}} {"Start":{"tun":{"name":null,"no_pi":false,"tun_excl":false,"tun_retrieve":false,"address":null}}}

View file

@ -1,8 +1,16 @@
#[cfg(any(target_os = "linux", target_vendor = "apple"))]
pub mod wireguard; pub mod wireguard;
#[cfg(any(target_os = "linux", target_vendor = "apple"))]
mod daemon; mod daemon;
#[cfg(any(target_os = "linux", target_vendor = "apple"))]
pub use daemon::{ pub use daemon::{
DaemonClient, DaemonCommand, DaemonResponse, DaemonResponseData, DaemonStartOptions, ServerInfo, DaemonClient,
DaemonCommand,
DaemonResponse,
DaemonResponseData,
DaemonStartOptions,
ServerInfo,
}; };
#[cfg(target_vendor = "apple")] #[cfg(target_vendor = "apple")]

View file

@ -1,5 +1,4 @@
use anyhow::{Context, Result}; use anyhow::{Context, Result};
#[cfg(any(target_os = "linux", target_vendor = "apple"))]
use clap::{Args, Parser, Subcommand}; use clap::{Args, Parser, Subcommand};
use tracing::instrument; use tracing::instrument;
use tracing_log::LogTracer; use tracing_log::LogTracer;
@ -8,12 +7,16 @@ use tracing_subscriber::{prelude::*, EnvFilter, FmtSubscriber};
#[cfg(any(target_os = "linux", target_vendor = "apple"))] #[cfg(any(target_os = "linux", target_vendor = "apple"))]
use tun::TunInterface; use tun::TunInterface;
#[cfg(any(target_os = "linux", target_vendor = "apple"))]
mod daemon; mod daemon;
#[cfg(any(target_os = "linux", target_vendor = "apple"))]
mod wireguard; mod wireguard;
#[cfg(any(target_os = "linux", target_vendor = "apple"))]
use daemon::{DaemonClient, DaemonCommand, DaemonStartOptions}; use daemon::{DaemonClient, DaemonCommand, DaemonStartOptions};
use tun::TunOptions; use tun::TunOptions;
#[cfg(any(target_os = "linux", target_vendor = "apple"))]
use crate::daemon::DaemonResponseData; use crate::daemon::DaemonResponseData;
#[derive(Parser)] #[derive(Parser)]
@ -184,6 +187,7 @@ async fn try_serverinfo() -> Result<()> {
async fn try_serverconfig() -> Result<()> { async fn try_serverconfig() -> Result<()> {
Ok(()) Ok(())
} }
#[cfg(any(target_os = "linux", target_vendor = "apple"))]
#[tokio::main(flavor = "current_thread")] #[tokio::main(flavor = "current_thread")]
async fn main() -> Result<()> { async fn main() -> Result<()> {
initialize_tracing().await?; initialize_tracing().await?;
@ -226,3 +230,8 @@ fn system_log() -> Result<Option<tracing_journald::Layer>> {
fn system_log() -> Result<Option<OsLogger>> { fn system_log() -> Result<Option<OsLogger>> {
Ok(Some(OsLogger::new("com.hackclub.burrow", "burrow-cli"))) Ok(Some(OsLogger::new("com.hackclub.burrow", "burrow-cli")))
} }
#[cfg(not(any(target_os = "linux", target_vendor = "apple")))]
pub fn main() {
eprintln!("This platform is not supported currently.")
}

View file

@ -1,13 +1,13 @@
use crate::wireguard::{Interface as WgInterface, Peer as WgPeer}; use std::{net::ToSocketAddrs, str::FromStr};
use anyhow::{anyhow, Error, Result}; use anyhow::{anyhow, Error, Result};
use base64::engine::general_purpose; use base64::{engine::general_purpose, Engine};
use base64::Engine;
use fehler::throws; use fehler::throws;
use ip_network::IpNetwork; use ip_network::IpNetwork;
use std::net::ToSocketAddrs;
use std::str::FromStr;
use x25519_dalek::{PublicKey, StaticSecret}; use x25519_dalek::{PublicKey, StaticSecret};
use crate::wireguard::{Interface as WgInterface, Peer as WgPeer};
#[throws] #[throws]
fn parse_key(string: &str) -> [u8; 32] { fn parse_key(string: &str) -> [u8; 32] {
let value = general_purpose::STANDARD.decode(string)?; let value = general_purpose::STANDARD.decode(string)?;
@ -68,12 +68,11 @@ impl TryFrom<Config> for WgInterface {
endpoint: p endpoint: p
.endpoint .endpoint
.to_socket_addrs()? .to_socket_addrs()?
.filter(|sock| sock.is_ipv4()) .find(|sock| sock.is_ipv4())
.next()
.ok_or(anyhow!("DNS Lookup Fails!"))?, .ok_or(anyhow!("DNS Lookup Fails!"))?,
preshared_key: match &p.preshared_key { preshared_key: match &p.preshared_key {
None => Ok(None), None => Ok(None),
Some(k) => parse_key(k).map(|res| Some(res)), Some(k) => parse_key(k).map(Some),
}?, }?,
allowed_ips: p allowed_ips: p
.allowed_ips .allowed_ips
@ -86,29 +85,28 @@ impl TryFrom<Config> for WgInterface {
}) })
}) })
.collect::<Result<Vec<WgPeer>>>()?; .collect::<Result<Vec<WgPeer>>>()?;
Ok(WgInterface::new(wg_peers)?) WgInterface::new(wg_peers)
} }
} }
impl Default for Config { impl Default for Config {
fn default() -> Self { fn default() -> Self {
Self{ Self {
interface: Interface{ interface: Interface {
private_key: "GNqIAOCRxjl/cicZyvkvpTklgQuUmGUIEkH7IXF/sEE=".into(), private_key: "GNqIAOCRxjl/cicZyvkvpTklgQuUmGUIEkH7IXF/sEE=".into(),
address: "10.13.13.2/24".into(), address: "10.13.13.2/24".into(),
listen_port: 51820, listen_port: 51820,
dns: Default::default(), dns: Default::default(),
mtu: Default::default() mtu: Default::default(),
}, },
peers: vec![Peer{ peers: vec![Peer {
endpoint: "wg.burrow.rs:51820".into(), endpoint: "wg.burrow.rs:51820".into(),
allowed_ips: vec!["8.8.8.8/32".into()], allowed_ips: vec!["8.8.8.8/32".into()],
public_key: "uy75leriJay0+oHLhRMpV+A5xAQ0hCJ+q7Ww81AOvT4=".into(), public_key: "uy75leriJay0+oHLhRMpV+A5xAQ0hCJ+q7Ww81AOvT4=".into(),
preshared_key: Some("s7lx/mg+reVEMnGnqeyYOQkzD86n2+gYnx1M9ygi08k=".into()), preshared_key: Some("s7lx/mg+reVEMnGnqeyYOQkzD86n2+gYnx1M9ygi08k=".into()),
persistent_keepalive: Default::default(), persistent_keepalive: Default::default(),
name: Default::default() name: Default::default(),
}] }],
} }
} }
} }

View file

@ -1,33 +1,15 @@
use std::{net::IpAddr, sync::Arc, time::Duration}; use std::{net::IpAddr, sync::Arc};
use anyhow::Error; use anyhow::Error;
use async_trait::async_trait;
use fehler::throws; use fehler::throws;
use futures::{future::join_all, FutureExt}; use futures::future::join_all;
use ip_network_table::IpNetworkTable; use ip_network_table::IpNetworkTable;
use tokio::{sync::RwLock, task::JoinHandle, time::timeout}; use tokio::sync::RwLock;
use tracing::{debug, error}; use tracing::{debug, error};
use tun::tokio::TunInterface; use tun::tokio::TunInterface;
use super::{noise::Tunnel, Peer, PeerPcb}; use super::{noise::Tunnel, Peer, PeerPcb};
#[async_trait]
pub trait PacketInterface {
async fn recv(&mut self, buf: &mut [u8]) -> Result<usize, tokio::io::Error>;
async fn send(&mut self, buf: &[u8]) -> Result<usize, tokio::io::Error>;
}
#[async_trait]
impl PacketInterface for tun::tokio::TunInterface {
async fn recv(&mut self, buf: &mut [u8]) -> Result<usize, tokio::io::Error> {
self.recv(buf).await
}
async fn send(&mut self, buf: &[u8]) -> Result<usize, tokio::io::Error> {
self.send(buf).await
}
}
struct IndexedPcbs { struct IndexedPcbs {
pcbs: Vec<Arc<PeerPcb>>, pcbs: Vec<Arc<PeerPcb>>,
allowed_ips: IpNetworkTable<usize>, allowed_ips: IpNetworkTable<usize>,
@ -44,7 +26,7 @@ impl IndexedPcbs {
pub fn insert(&mut self, pcb: PeerPcb) { pub fn insert(&mut self, pcb: PeerPcb) {
let idx: usize = self.pcbs.len(); let idx: usize = self.pcbs.len();
for allowed_ip in pcb.allowed_ips.iter() { for allowed_ip in pcb.allowed_ips.iter() {
self.allowed_ips.insert(allowed_ip.clone(), idx); self.allowed_ips.insert(*allowed_ip, idx);
} }
self.pcbs.insert(idx, Arc::new(pcb)); self.pcbs.insert(idx, Arc::new(pcb));
} }
@ -53,10 +35,6 @@ impl IndexedPcbs {
let (_, &idx) = self.allowed_ips.longest_match(addr)?; let (_, &idx) = self.allowed_ips.longest_match(addr)?;
Some(idx) Some(idx)
} }
pub async fn connect(&self, idx: usize, handle: JoinHandle<()>) {
self.pcbs[idx].handle.write().await.replace(handle);
}
} }
impl FromIterator<PeerPcb> for IndexedPcbs { impl FromIterator<PeerPcb> for IndexedPcbs {
@ -78,7 +56,7 @@ impl Interface {
pub fn new<I: IntoIterator<Item = Peer>>(peers: I) -> Self { pub fn new<I: IntoIterator<Item = Peer>>(peers: I) -> Self {
let pcbs: IndexedPcbs = peers let pcbs: IndexedPcbs = peers
.into_iter() .into_iter()
.map(|peer| PeerPcb::new(peer)) .map(PeerPcb::new)
.collect::<Result<_, _>>()?; .collect::<Result<_, _>>()?;
let pcbs = Arc::new(pcbs); let pcbs = Arc::new(pcbs);
@ -90,63 +68,53 @@ impl Interface {
} }
pub async fn run(&self) -> anyhow::Result<()> { pub async fn run(&self) -> anyhow::Result<()> {
debug!("RUN: starting interface");
let pcbs = self.pcbs.clone(); let pcbs = self.pcbs.clone();
let tun = self let tun = self
.tun .tun
.clone() .clone()
.ok_or(anyhow::anyhow!("tun interface does not exist"))?; .ok_or(anyhow::anyhow!("tun interface does not exist"))?;
log::info!("starting interface"); log::info!("Starting interface");
let outgoing = async move { let outgoing = async move {
loop { loop {
// tracing::debug!("starting loop...");
let mut buf = [0u8; 3000]; let mut buf = [0u8; 3000];
let src = { let src = {
let src = match timeout( let src = match tun.read().await.recv(&mut buf[..]).await {
Duration::from_millis(10), Ok(len) => &buf[..len],
tun.read().await.recv(&mut buf[..]), Err(e) => {
) error!("Failed to read from interface: {}", e);
.await continue
{
Ok(Ok(len)) => &buf[..len],
Ok(Err(e)) => {
error!("failed to read from interface: {}", e);
continue;
} }
Err(_would_block) => continue,
}; };
debug!("read {} bytes from interface", src.len()); debug!("Read {} bytes from interface", src.len());
debug!("bytes: {:?}", src);
src src
}; };
let dst_addr = match Tunnel::dst_address(src) { let dst_addr = match Tunnel::dst_address(src) {
Some(addr) => addr, Some(addr) => addr,
None => { None => {
tracing::debug!("no destination found"); debug!("No destination found");
continue; continue
} }
}; };
tracing::debug!("dst_addr: {}", dst_addr); debug!("Routing packet to {}", dst_addr);
debug!("src_addr: {}", Tunnel::src_address(src).unwrap());
let Some(idx) = pcbs.find(dst_addr) else { let Some(idx) = pcbs.find(dst_addr) else {
continue continue
}; };
tracing::debug!("found peer:{}", idx); debug!("Found peer:{}", idx);
match pcbs.pcbs[idx].send(src).await { match pcbs.pcbs[idx].send(src).await {
Ok(..) => { Ok(..) => {
let addr = pcbs.pcbs[idx].endpoint; let addr = pcbs.pcbs[idx].endpoint;
tracing::debug!("sent packet to peer {}", addr); debug!("Sent packet to peer {}", addr);
} }
Err(e) => { Err(e) => {
log::error!("failed to send packet {}", e); log::error!("Failed to send packet {}", e);
continue; continue
} }
}; };
} }
@ -170,14 +138,13 @@ impl Interface {
let tsk = async move { let tsk = async move {
if let Err(e) = pcb.open_if_closed().await { if let Err(e) = pcb.open_if_closed().await {
log::error!("failed to open pcb: {}", e); log::error!("failed to open pcb: {}", e);
return; return
} }
let r2 = pcb.run(tun).await; let r2 = pcb.run(tun).await;
if let Err(e) = r2 { if let Err(e) = r2 {
log::error!("failed to run pcb: {}", e); log::error!("failed to run pcb: {}", e);
return;
} else { } else {
tracing::debug!("pcb ran successfully"); debug!("pcb ran successfully");
} }
}; };
debug!("task made.."); debug!("task made..");

View file

@ -4,21 +4,8 @@ mod noise;
mod pcb; mod pcb;
mod peer; mod peer;
pub use config::Config;
pub use iface::Interface; pub use iface::Interface;
pub use pcb::PeerPcb; pub use pcb::PeerPcb;
pub use peer::Peer; pub use peer::Peer;
pub use x25519_dalek::{PublicKey, StaticSecret}; pub use x25519_dalek::{PublicKey, StaticSecret};
pub use config::Config;
const WIREGUARD_CONFIG: &str = r#"
[Interface]
# Device: Gentle Tomcat
PrivateKey = sIxpokQPnWctJKNaQ3DRdcQbL2S5OMbUrvr4bbsvTHw=
Address = 10.68.136.199/32,fc00:bbbb:bbbb:bb01::5:88c6/128
DNS = 10.64.0.1
[Peer]
public_key = EKZXvHlSDeqAjfC/m9aQR0oXfQ6Idgffa9L0DH5yaCo=
AllowedIPs = 0.0.0.0/0,::0/0
Endpoint = 146.70.173.66:51820
"#;

View file

@ -4,9 +4,7 @@
#[derive(Debug)] #[derive(Debug)]
pub enum WireGuardError { pub enum WireGuardError {
DestinationBufferTooSmall, DestinationBufferTooSmall,
IncorrectPacketLength,
UnexpectedPacket, UnexpectedPacket,
WrongPacketType,
WrongIndex, WrongIndex,
WrongKey, WrongKey,
InvalidTai64nTimestamp, InvalidTai64nTimestamp,
@ -17,7 +15,6 @@ pub enum WireGuardError {
DuplicateCounter, DuplicateCounter,
InvalidPacket, InvalidPacket,
NoCurrentSession, NoCurrentSession,
LockFailed,
ConnectionExpired, ConnectionExpired,
UnderLoad, UnderLoad,
} }

View file

@ -9,14 +9,20 @@ use std::{
use aead::{Aead, Payload}; use aead::{Aead, Payload};
use blake2::{ use blake2::{
digest::{FixedOutput, KeyInit}, digest::{FixedOutput, KeyInit},
Blake2s256, Blake2sMac, Digest, Blake2s256,
Blake2sMac,
Digest,
}; };
use chacha20poly1305::XChaCha20Poly1305; use chacha20poly1305::XChaCha20Poly1305;
use rand_core::OsRng; use rand_core::OsRng;
use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305}; use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305};
use super::{ use super::{
errors::WireGuardError, session::Session, x25519, HandshakeInit, HandshakeResponse, errors::WireGuardError,
session::Session,
x25519,
HandshakeInit,
HandshakeResponse,
PacketCookieReply, PacketCookieReply,
}; };
@ -203,7 +209,7 @@ impl Tai64N {
/// Parse a timestamp from a 12 byte u8 slice /// Parse a timestamp from a 12 byte u8 slice
fn parse(buf: &[u8; 12]) -> Result<Tai64N, WireGuardError> { fn parse(buf: &[u8; 12]) -> Result<Tai64N, WireGuardError> {
if buf.len() < 12 { if buf.len() < 12 {
return Err(WireGuardError::InvalidTai64nTimestamp); return Err(WireGuardError::InvalidTai64nTimestamp)
} }
let (sec_bytes, nano_bytes) = buf.split_at(std::mem::size_of::<u64>()); let (sec_bytes, nano_bytes) = buf.split_at(std::mem::size_of::<u64>());
@ -550,22 +556,19 @@ impl Handshake {
let timestamp = Tai64N::parse(&timestamp)?; let timestamp = Tai64N::parse(&timestamp)?;
if !timestamp.after(&self.last_handshake_timestamp) { if !timestamp.after(&self.last_handshake_timestamp) {
// Possibly a replay // Possibly a replay
return Err(WireGuardError::WrongTai64nTimestamp); return Err(WireGuardError::WrongTai64nTimestamp)
} }
self.last_handshake_timestamp = timestamp; self.last_handshake_timestamp = timestamp;
// initiator.hash = HASH(initiator.hash || msg.encrypted_timestamp) // initiator.hash = HASH(initiator.hash || msg.encrypted_timestamp)
hash = b2s_hash(&hash, packet.encrypted_timestamp); hash = b2s_hash(&hash, packet.encrypted_timestamp);
self.previous = std::mem::replace( self.previous = std::mem::replace(&mut self.state, HandshakeState::InitReceived {
&mut self.state, chaining_key,
HandshakeState::InitReceived { hash,
chaining_key, peer_ephemeral_public,
hash, peer_index,
peer_ephemeral_public, });
peer_index,
},
);
self.format_handshake_response(dst) self.format_handshake_response(dst)
} }
@ -666,7 +669,7 @@ impl Handshake {
let local_index = self.cookies.index; let local_index = self.cookies.index;
if packet.receiver_idx != local_index { if packet.receiver_idx != local_index {
return Err(WireGuardError::WrongIndex); return Err(WireGuardError::WrongIndex)
} }
// msg.encrypted_cookie = XAEAD(HASH(LABEL_COOKIE || responder.static_public), // msg.encrypted_cookie = XAEAD(HASH(LABEL_COOKIE || responder.static_public),
// msg.nonce, cookie, last_received_msg.mac1) // msg.nonce, cookie, last_received_msg.mac1)
@ -722,7 +725,7 @@ impl Handshake {
dst: &'a mut [u8], dst: &'a mut [u8],
) -> Result<&'a mut [u8], WireGuardError> { ) -> Result<&'a mut [u8], WireGuardError> {
if dst.len() < super::HANDSHAKE_INIT_SZ { if dst.len() < super::HANDSHAKE_INIT_SZ {
return Err(WireGuardError::DestinationBufferTooSmall); return Err(WireGuardError::DestinationBufferTooSmall)
} }
let (message_type, rest) = dst.split_at_mut(4); let (message_type, rest) = dst.split_at_mut(4);
@ -805,7 +808,7 @@ impl Handshake {
dst: &'a mut [u8], dst: &'a mut [u8],
) -> Result<(&'a mut [u8], Session), WireGuardError> { ) -> Result<(&'a mut [u8], Session), WireGuardError> {
if dst.len() < super::HANDSHAKE_RESP_SZ { if dst.len() < super::HANDSHAKE_RESP_SZ {
return Err(WireGuardError::DestinationBufferTooSmall); return Err(WireGuardError::DestinationBufferTooSmall)
} }
let state = std::mem::replace(&mut self.state, HandshakeState::None); let state = std::mem::replace(&mut self.state, HandshakeState::None);

View file

@ -45,7 +45,11 @@ const N_SESSIONS: usize = 8;
pub mod x25519 { pub mod x25519 {
pub use x25519_dalek::{ pub use x25519_dalek::{
EphemeralSecret, PublicKey, ReusableSecret, SharedSecret, StaticSecret, EphemeralSecret,
PublicKey,
ReusableSecret,
SharedSecret,
StaticSecret,
}; };
} }
@ -129,15 +133,15 @@ pub struct PacketData<'a> {
pub enum Packet<'a> { pub enum Packet<'a> {
HandshakeInit(HandshakeInit<'a>), HandshakeInit(HandshakeInit<'a>),
HandshakeResponse(HandshakeResponse<'a>), HandshakeResponse(HandshakeResponse<'a>),
PacketCookieReply(PacketCookieReply<'a>), CookieReply(PacketCookieReply<'a>),
PacketData(PacketData<'a>), Data(PacketData<'a>),
} }
impl Tunnel { impl Tunnel {
#[inline(always)] #[inline(always)]
pub fn parse_incoming_packet(src: &[u8]) -> Result<Packet, WireGuardError> { pub fn parse_incoming_packet(src: &[u8]) -> Result<Packet, WireGuardError> {
if src.len() < 4 { if src.len() < 4 {
return Err(WireGuardError::InvalidPacket); return Err(WireGuardError::InvalidPacket)
} }
// Checks the type, as well as the reserved zero fields // Checks the type, as well as the reserved zero fields
@ -159,12 +163,12 @@ impl Tunnel {
.expect("length already checked above"), .expect("length already checked above"),
encrypted_nothing: &src[44..60], encrypted_nothing: &src[44..60],
}), }),
(COOKIE_REPLY, COOKIE_REPLY_SZ) => Packet::PacketCookieReply(PacketCookieReply { (COOKIE_REPLY, COOKIE_REPLY_SZ) => Packet::CookieReply(PacketCookieReply {
receiver_idx: u32::from_le_bytes(src[4..8].try_into().unwrap()), receiver_idx: u32::from_le_bytes(src[4..8].try_into().unwrap()),
nonce: &src[8..32], nonce: &src[8..32],
encrypted_cookie: &src[32..64], encrypted_cookie: &src[32..64],
}), }),
(DATA, DATA_OVERHEAD_SZ..=std::usize::MAX) => Packet::PacketData(PacketData { (DATA, DATA_OVERHEAD_SZ..=std::usize::MAX) => Packet::Data(PacketData {
receiver_idx: u32::from_le_bytes(src[4..8].try_into().unwrap()), receiver_idx: u32::from_le_bytes(src[4..8].try_into().unwrap()),
counter: u64::from_le_bytes(src[8..16].try_into().unwrap()), counter: u64::from_le_bytes(src[8..16].try_into().unwrap()),
encrypted_encapsulated_packet: &src[16..], encrypted_encapsulated_packet: &src[16..],
@ -179,7 +183,7 @@ impl Tunnel {
pub fn dst_address(packet: &[u8]) -> Option<IpAddr> { pub fn dst_address(packet: &[u8]) -> Option<IpAddr> {
if packet.is_empty() { if packet.is_empty() {
return None; return None
} }
match packet[0] >> 4 { match packet[0] >> 4 {
@ -203,7 +207,7 @@ impl Tunnel {
pub fn src_address(packet: &[u8]) -> Option<IpAddr> { pub fn src_address(packet: &[u8]) -> Option<IpAddr> {
if packet.is_empty() { if packet.is_empty() {
return None; return None
} }
match packet[0] >> 4 { match packet[0] >> 4 {
@ -298,7 +302,7 @@ impl Tunnel {
self.timer_tick(TimerName::TimeLastDataPacketSent); self.timer_tick(TimerName::TimeLastDataPacketSent);
} }
self.tx_bytes += src.len(); self.tx_bytes += src.len();
return TunnResult::WriteToNetwork(packet); return TunnResult::WriteToNetwork(packet)
} }
// If there is no session, queue the packet for future retry // If there is no session, queue the packet for future retry
@ -322,7 +326,7 @@ impl Tunnel {
) -> TunnResult<'a> { ) -> TunnResult<'a> {
if datagram.is_empty() { if datagram.is_empty() {
// Indicates a repeated call // Indicates a repeated call
return self.send_queued_packet(dst); return self.send_queued_packet(dst)
} }
let mut cookie = [0u8; COOKIE_REPLY_SZ]; let mut cookie = [0u8; COOKIE_REPLY_SZ];
@ -333,7 +337,7 @@ impl Tunnel {
Ok(packet) => packet, Ok(packet) => packet,
Err(TunnResult::WriteToNetwork(cookie)) => { Err(TunnResult::WriteToNetwork(cookie)) => {
dst[..cookie.len()].copy_from_slice(cookie); dst[..cookie.len()].copy_from_slice(cookie);
return TunnResult::WriteToNetwork(&mut dst[..cookie.len()]); return TunnResult::WriteToNetwork(&mut dst[..cookie.len()])
} }
Err(TunnResult::Err(e)) => return TunnResult::Err(e), Err(TunnResult::Err(e)) => return TunnResult::Err(e),
_ => unreachable!(), _ => unreachable!(),
@ -350,8 +354,8 @@ impl Tunnel {
match packet { match packet {
Packet::HandshakeInit(p) => self.handle_handshake_init(p, dst), Packet::HandshakeInit(p) => self.handle_handshake_init(p, dst),
Packet::HandshakeResponse(p) => self.handle_handshake_response(p, dst), Packet::HandshakeResponse(p) => self.handle_handshake_response(p, dst),
Packet::PacketCookieReply(p) => self.handle_cookie_reply(p), Packet::CookieReply(p) => self.handle_cookie_reply(p),
Packet::PacketData(p) => self.handle_data(p, dst), Packet::Data(p) => self.handle_data(p, dst),
} }
.unwrap_or_else(TunnResult::from) .unwrap_or_else(TunnResult::from)
} }
@ -433,7 +437,7 @@ impl Tunnel {
let cur_idx = self.current; let cur_idx = self.current;
if cur_idx == new_idx { if cur_idx == new_idx {
// There is nothing to do, already using this session, this is the common case // There is nothing to do, already using this session, this is the common case
return; return
} }
if self.sessions[cur_idx % N_SESSIONS].is_none() if self.sessions[cur_idx % N_SESSIONS].is_none()
|| self.timers.session_timers[new_idx % N_SESSIONS] || self.timers.session_timers[new_idx % N_SESSIONS]
@ -479,7 +483,7 @@ impl Tunnel {
force_resend: bool, force_resend: bool,
) -> TunnResult<'a> { ) -> TunnResult<'a> {
if self.handshake.is_in_progress() && !force_resend { if self.handshake.is_in_progress() && !force_resend {
return TunnResult::Done; return TunnResult::Done
} }
if self.handshake.is_expired() { if self.handshake.is_expired() {
@ -538,7 +542,7 @@ impl Tunnel {
}; };
if computed_len > packet.len() { if computed_len > packet.len() {
return TunnResult::Err(WireGuardError::InvalidPacket); return TunnResult::Err(WireGuardError::InvalidPacket)
} }
self.timer_tick(TimerName::TimeLastDataPacketReceived); self.timer_tick(TimerName::TimeLastDataPacketReceived);

View file

@ -12,9 +12,19 @@ use ring::constant_time::verify_slices_are_equal;
use super::{ use super::{
handshake::{ handshake::{
b2s_hash, b2s_keyed_mac_16, b2s_keyed_mac_16_2, b2s_mac_24, LABEL_COOKIE, LABEL_MAC1, b2s_hash,
b2s_keyed_mac_16,
b2s_keyed_mac_16_2,
b2s_mac_24,
LABEL_COOKIE,
LABEL_MAC1,
}, },
HandshakeInit, HandshakeResponse, Packet, TunnResult, Tunnel, WireGuardError, HandshakeInit,
HandshakeResponse,
Packet,
TunnResult,
Tunnel,
WireGuardError,
}; };
const COOKIE_REFRESH: u64 = 128; // Use 128 and not 120 so the compiler can optimize out the division const COOKIE_REFRESH: u64 = 128; // Use 128 and not 120 so the compiler can optimize out the division
@ -126,7 +136,7 @@ impl RateLimiter {
dst: &'a mut [u8], dst: &'a mut [u8],
) -> Result<&'a mut [u8], WireGuardError> { ) -> Result<&'a mut [u8], WireGuardError> {
if dst.len() < super::COOKIE_REPLY_SZ { if dst.len() < super::COOKIE_REPLY_SZ {
return Err(WireGuardError::DestinationBufferTooSmall); return Err(WireGuardError::DestinationBufferTooSmall)
} }
let (message_type, rest) = dst.split_at_mut(4); let (message_type, rest) = dst.split_at_mut(4);
@ -192,7 +202,7 @@ impl RateLimiter {
let cookie_packet = self let cookie_packet = self
.format_cookie_reply(sender_idx, cookie, mac1, dst) .format_cookie_reply(sender_idx, cookie, mac1, dst)
.map_err(TunnResult::Err)?; .map_err(TunnResult::Err)?;
return Err(TunnResult::WriteToNetwork(cookie_packet)); return Err(TunnResult::WriteToNetwork(cookie_packet))
} }
} }
} }

View file

@ -88,11 +88,11 @@ impl ReceivingKeyCounterValidator {
fn will_accept(&self, counter: u64) -> Result<(), WireGuardError> { fn will_accept(&self, counter: u64) -> Result<(), WireGuardError> {
if counter >= self.next { if counter >= self.next {
// As long as the counter is growing no replay took place for sure // As long as the counter is growing no replay took place for sure
return Ok(()); return Ok(())
} }
if counter + N_BITS < self.next { if counter + N_BITS < self.next {
// Drop if too far back // Drop if too far back
return Err(WireGuardError::InvalidCounter); return Err(WireGuardError::InvalidCounter)
} }
if !self.check_bit(counter) { if !self.check_bit(counter) {
Ok(()) Ok(())
@ -107,22 +107,22 @@ impl ReceivingKeyCounterValidator {
fn mark_did_receive(&mut self, counter: u64) -> Result<(), WireGuardError> { fn mark_did_receive(&mut self, counter: u64) -> Result<(), WireGuardError> {
if counter + N_BITS < self.next { if counter + N_BITS < self.next {
// Drop if too far back // Drop if too far back
return Err(WireGuardError::InvalidCounter); return Err(WireGuardError::InvalidCounter)
} }
if counter == self.next { if counter == self.next {
// Usually the packets arrive in order, in that case we simply mark the bit and // Usually the packets arrive in order, in that case we simply mark the bit and
// increment the counter // increment the counter
self.set_bit(counter); self.set_bit(counter);
self.next += 1; self.next += 1;
return Ok(()); return Ok(())
} }
if counter < self.next { if counter < self.next {
// A packet arrived out of order, check if it is valid, and mark // A packet arrived out of order, check if it is valid, and mark
if self.check_bit(counter) { if self.check_bit(counter) {
return Err(WireGuardError::InvalidCounter); return Err(WireGuardError::InvalidCounter)
} }
self.set_bit(counter); self.set_bit(counter);
return Ok(()); return Ok(())
} }
// Packets where dropped, or maybe reordered, skip them and mark unused // Packets where dropped, or maybe reordered, skip them and mark unused
if counter - self.next >= N_BITS { if counter - self.next >= N_BITS {
@ -247,7 +247,7 @@ impl Session {
panic!("The destination buffer is too small"); panic!("The destination buffer is too small");
} }
if packet.receiver_idx != self.receiving_index { if packet.receiver_idx != self.receiving_index {
return Err(WireGuardError::WrongIndex); return Err(WireGuardError::WrongIndex)
} }
// Don't reuse counters, in case this is a replay attack we want to quickly // Don't reuse counters, in case this is a replay attack we want to quickly
// check the counter without running expensive decryption // check the counter without running expensive decryption

View file

@ -190,7 +190,7 @@ impl Tunnel {
{ {
if self.handshake.is_expired() { if self.handshake.is_expired() {
return TunnResult::Err(WireGuardError::ConnectionExpired); return TunnResult::Err(WireGuardError::ConnectionExpired)
} }
// Clear cookie after COOKIE_EXPIRATION_TIME // Clear cookie after COOKIE_EXPIRATION_TIME
@ -206,7 +206,7 @@ impl Tunnel {
tracing::error!("CONNECTION_EXPIRED(REJECT_AFTER_TIME * 3)"); tracing::error!("CONNECTION_EXPIRED(REJECT_AFTER_TIME * 3)");
self.handshake.set_expired(); self.handshake.set_expired();
self.clear_all(); self.clear_all();
return TunnResult::Err(WireGuardError::ConnectionExpired); return TunnResult::Err(WireGuardError::ConnectionExpired)
} }
if let Some(time_init_sent) = self.handshake.timer() { if let Some(time_init_sent) = self.handshake.timer() {
@ -219,7 +219,7 @@ impl Tunnel {
tracing::error!("CONNECTION_EXPIRED(REKEY_ATTEMPT_TIME)"); tracing::error!("CONNECTION_EXPIRED(REKEY_ATTEMPT_TIME)");
self.handshake.set_expired(); self.handshake.set_expired();
self.clear_all(); self.clear_all();
return TunnResult::Err(WireGuardError::ConnectionExpired); return TunnResult::Err(WireGuardError::ConnectionExpired)
} }
if time_init_sent.elapsed() >= REKEY_TIMEOUT { if time_init_sent.elapsed() >= REKEY_TIMEOUT {
@ -299,11 +299,11 @@ impl Tunnel {
} }
if handshake_initiation_required { if handshake_initiation_required {
return self.format_handshake_initiation(dst, true); return self.format_handshake_initiation(dst, true)
} }
if keepalive_required { if keepalive_required {
return self.encapsulate(&[], dst); return self.encapsulate(&[], dst)
} }
TunnResult::Done TunnResult::Done

View file

@ -1,10 +1,6 @@
use std::{ use std::{net::SocketAddr, sync::Arc};
cell::{Cell, RefCell},
net::SocketAddr,
sync::Arc,
};
use anyhow::{anyhow, Error}; use anyhow::Error;
use fehler::throws; use fehler::throws;
use ip_network::IpNetwork; use ip_network::IpNetwork;
use rand::random; use rand::random;
@ -74,7 +70,7 @@ impl PeerPcb {
Ok(l) => l, Ok(l) => l,
Err(e) => { Err(e) => {
log::error!("{}: error reading from socket: {:?}", rid, e); log::error!("{}: error reading from socket: {:?}", rid, e);
continue; continue
} }
}; };
let mut res_dat = &res_buf[..len]; let mut res_dat = &res_buf[..len];
@ -90,7 +86,7 @@ impl PeerPcb {
TunnResult::Done => break, TunnResult::Done => break,
TunnResult::Err(e) => { TunnResult::Err(e) => {
tracing::error!(message = "Decapsulate error", error = ?e); tracing::error!(message = "Decapsulate error", error = ?e);
break; break
} }
TunnResult::WriteToNetwork(packet) => { TunnResult::WriteToNetwork(packet) => {
tracing::debug!("WriteToNetwork: {:?}", packet); tracing::debug!("WriteToNetwork: {:?}", packet);
@ -98,17 +94,17 @@ impl PeerPcb {
socket.send(packet).await?; socket.send(packet).await?;
tracing::debug!("WriteToNetwork done"); tracing::debug!("WriteToNetwork done");
res_dat = &[]; res_dat = &[];
continue; continue
} }
TunnResult::WriteToTunnelV4(packet, addr) => { TunnResult::WriteToTunnelV4(packet, addr) => {
tracing::debug!("WriteToTunnelV4: {:?}, {:?}", packet, addr); tracing::debug!("WriteToTunnelV4: {:?}, {:?}", packet, addr);
tun_interface.read().await.send(packet).await?; tun_interface.read().await.send(packet).await?;
break; break
} }
TunnResult::WriteToTunnelV6(packet, addr) => { TunnResult::WriteToTunnelV6(packet, addr) => {
tracing::debug!("WriteToTunnelV6: {:?}, {:?}", packet, addr); tracing::debug!("WriteToTunnelV6: {:?}, {:?}", packet, addr);
tun_interface.read().await.send(packet).await?; tun_interface.read().await.send(packet).await?;
break; break
} }
} }
} }

View file

@ -39,5 +39,5 @@ anyhow = "1.0"
bindgen = "0.65" bindgen = "0.65"
reqwest = { version = "0.11", features = ["native-tls"] } reqwest = { version = "0.11", features = ["native-tls"] }
ssri = { version = "9.0", default-features = false } ssri = { version = "9.0", default-features = false }
tokio = { version = "1.28", features = ["rt"] } tokio = { version = "1.28", features = ["rt", "macros"] }
zip = { version = "0.6", features = ["deflate"] } zip = { version = "0.6", features = ["deflate"] }

View file

@ -26,7 +26,7 @@ async fn generate(out_dir: &std::path::Path) -> anyhow::Result<()> {
println!("cargo:rerun-if-changed={}", binary_path.to_str().unwrap()); println!("cargo:rerun-if-changed={}", binary_path.to_str().unwrap());
if let (Ok(..), Ok(..)) = (File::open(&bindings_path), File::open(&binary_path)) { if let (Ok(..), Ok(..)) = (File::open(&bindings_path), File::open(&binary_path)) {
return Ok(()); return Ok(())
}; };
let archive = download(out_dir) let archive = download(out_dir)

View file

@ -2,6 +2,8 @@ use std::io::Error;
use fehler::throws; use fehler::throws;
#[cfg(any(target_os = "linux", target_vendor = "apple"))]
#[cfg(feature = "tokio")]
use super::tokio::TunInterface; use super::tokio::TunInterface;
#[derive(Debug, Clone, Default)] #[derive(Debug, Clone, Default)]
@ -47,6 +49,8 @@ impl TunOptions {
self self
} }
#[cfg(any(target_os = "linux", target_vendor = "apple"))]
#[cfg(feature = "tokio")]
#[throws] #[throws]
pub fn open(self) -> TunInterface { pub fn open(self) -> TunInterface {
let ti = super::TunInterface::new_with_options(self)?; let ti = super::TunInterface::new_with_options(self)?;

View file

@ -34,7 +34,7 @@ impl TunInterface {
Ok(result) => return result, Ok(result) => return result,
Err(_would_block) => { Err(_would_block) => {
tracing::debug!("WouldBlock"); tracing::debug!("WouldBlock");
continue; continue
} }
} }
} }

View file

@ -1,6 +1,6 @@
use std::{ use std::{
io::{Error, IoSlice}, io::{Error, IoSlice},
mem::{self, ManuallyDrop}, mem,
net::{Ipv4Addr, SocketAddrV4}, net::{Ipv4Addr, SocketAddrV4},
os::fd::{AsRawFd, FromRawFd, RawFd}, os::fd::{AsRawFd, FromRawFd, RawFd},
}; };

View file

@ -2,11 +2,20 @@ use std::mem;
use libc::{c_char, c_int, c_short, c_uint, c_ulong, sockaddr}; use libc::{c_char, c_int, c_short, c_uint, c_ulong, sockaddr};
pub use libc::{ pub use libc::{
c_void, sockaddr_ctl, sockaddr_in, socklen_t, AF_SYSTEM, AF_SYS_CONTROL, IFNAMSIZ, c_void,
sockaddr_ctl,
sockaddr_in,
socklen_t,
AF_SYSTEM,
AF_SYS_CONTROL,
IFNAMSIZ,
SYSPROTO_CONTROL, SYSPROTO_CONTROL,
}; };
use nix::{ use nix::{
ioctl_read_bad, ioctl_readwrite, ioctl_write_ptr_bad, request_code_readwrite, ioctl_read_bad,
ioctl_readwrite,
ioctl_write_ptr_bad,
request_code_readwrite,
request_code_write, request_code_write,
}; };

View file

@ -13,6 +13,7 @@ use fehler::throws;
use libc::in6_ifreq; use libc::in6_ifreq;
use socket2::{Domain, SockAddr, Socket, Type}; use socket2::{Domain, SockAddr, Socket, Type};
use tracing::{info, instrument}; use tracing::{info, instrument};
use super::{ifname_to_string, string_to_ifname}; use super::{ifname_to_string, string_to_ifname};
use crate::TunOptions; use crate::TunOptions;

View file

@ -1,5 +1,5 @@
use std::{ use std::{
io::{Error, Read}, io::Error,
mem::MaybeUninit, mem::MaybeUninit,
os::fd::{AsRawFd, FromRawFd, IntoRawFd, RawFd}, os::fd::{AsRawFd, FromRawFd, IntoRawFd, RawFd},
}; };
@ -51,7 +51,7 @@ impl TunInterface {
let mut tmp_buf = [MaybeUninit::uninit(); 1500]; let mut tmp_buf = [MaybeUninit::uninit(); 1500];
let len = self.socket.recv(&mut tmp_buf)?; let len = self.socket.recv(&mut tmp_buf)?;
let result_buf = unsafe { assume_init(&tmp_buf[4..len]) }; let result_buf = unsafe { assume_init(&tmp_buf[4..len]) };
buf[..len - 4].copy_from_slice(&result_buf); buf[..len - 4].copy_from_slice(result_buf);
len - 4 len - 4
} }

View file

@ -11,7 +11,7 @@ fn tst_read() {
// This test is interactive, you need to send a packet to any server through // This test is interactive, you need to send a packet to any server through
// 192.168.1.10 EG. `sudo route add 8.8.8.8 192.168.1.10`, // 192.168.1.10 EG. `sudo route add 8.8.8.8 192.168.1.10`,
//`dig @8.8.8.8 hackclub.com` //`dig @8.8.8.8 hackclub.com`
let mut tun = TunInterface::new()?; let tun = TunInterface::new()?;
println!("tun name: {:?}", tun.name()?); println!("tun name: {:?}", tun.name()?);
tun.set_ipv4_addr(Ipv4Addr::from([192, 168, 1, 10]))?; tun.set_ipv4_addr(Ipv4Addr::from([192, 168, 1, 10]))?;
println!("tun ip: {:?}", tun.ipv4_addr()?); println!("tun ip: {:?}", tun.ipv4_addr()?);