Compare commits
No commits in common. "1b39eca069edb3d48b27b19971e6627b816e6697" and "a2e93278c12351fd4c9e24d1615027aeca003ed3" have entirely different histories.
1b39eca069
...
a2e93278c1
44 changed files with 314 additions and 4888 deletions
16
.github/workflows/build-rust.yml
vendored
16
.github/workflows/build-rust.yml
vendored
|
|
@ -17,24 +17,21 @@ jobs:
|
||||||
platform: Linux
|
platform: Linux
|
||||||
packages:
|
packages:
|
||||||
- gcc-aarch64-linux-gnu
|
- gcc-aarch64-linux-gnu
|
||||||
test-targets:
|
|
||||||
- x86_64-unknown-linux-gnu
|
|
||||||
targets:
|
targets:
|
||||||
|
- x86_64-unknown-linux-gnu
|
||||||
- aarch64-unknown-linux-gnu
|
- aarch64-unknown-linux-gnu
|
||||||
- os: macos-12
|
- os: macos-12
|
||||||
platform: macOS
|
platform: macOS
|
||||||
test-targets:
|
|
||||||
- x86_64-apple-darwin
|
|
||||||
targets:
|
targets:
|
||||||
|
- x86_64-apple-darwin
|
||||||
- aarch64-apple-darwin
|
- aarch64-apple-darwin
|
||||||
- aarch64-apple-ios
|
- aarch64-apple-ios
|
||||||
- aarch64-apple-ios-sim
|
- aarch64-apple-ios-sim
|
||||||
- x86_64-apple-ios
|
- x86_64-apple-ios
|
||||||
- os: windows-2022
|
- os: windows-2022
|
||||||
platform: Windows
|
platform: Windows
|
||||||
test-targets:
|
|
||||||
- x86_64-pc-windows-msvc
|
|
||||||
targets:
|
targets:
|
||||||
|
- x86_64-pc-windows-msvc
|
||||||
- aarch64-pc-windows-msvc
|
- aarch64-pc-windows-msvc
|
||||||
runs-on: ${{ matrix.os }}
|
runs-on: ${{ matrix.os }}
|
||||||
env:
|
env:
|
||||||
|
|
@ -60,11 +57,6 @@ jobs:
|
||||||
toolchain: stable
|
toolchain: stable
|
||||||
components: rustfmt
|
components: rustfmt
|
||||||
targets: ${{ join(matrix.targets, ', ') }}
|
targets: ${{ join(matrix.targets, ', ') }}
|
||||||
- name: Setup Rust Cache
|
|
||||||
uses: Swatinem/rust-cache@v2
|
|
||||||
- name: Build
|
- name: Build
|
||||||
shell: bash
|
shell: bash
|
||||||
run: cargo build --verbose --workspace --all-features --target ${{ join(matrix.targets, ' --target ') }} --target ${{ join(matrix.test-targets, ' --target ') }}
|
run: cargo build --verbose --workspace --all-features --target ${{ join(matrix.targets, ' --target ') }}
|
||||||
- name: Test
|
|
||||||
shell: bash
|
|
||||||
run: cargo test --verbose --workspace --all-features --target ${{ join(matrix.test-targets, ' --target ') }}
|
|
||||||
|
|
|
||||||
9
.vscode/settings.json
vendored
9
.vscode/settings.json
vendored
|
|
@ -8,6 +8,15 @@
|
||||||
"editor.acceptSuggestionOnEnter": "on",
|
"editor.acceptSuggestionOnEnter": "on",
|
||||||
"rust-analyzer.restartServerOnConfigChange": true,
|
"rust-analyzer.restartServerOnConfigChange": true,
|
||||||
"rust-analyzer.cargo.features": "all",
|
"rust-analyzer.cargo.features": "all",
|
||||||
|
"rust-analyzer.check.overrideCommand": [
|
||||||
|
"cargo",
|
||||||
|
"clippy",
|
||||||
|
"--fix",
|
||||||
|
"--workspace",
|
||||||
|
"--message-format=json",
|
||||||
|
"--all-targets",
|
||||||
|
"--allow-dirty"
|
||||||
|
],
|
||||||
"[rust]": {
|
"[rust]": {
|
||||||
"editor.defaultFormatter": "rust-lang.rust-analyzer",
|
"editor.defaultFormatter": "rust-lang.rust-analyzer",
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,22 +1,7 @@
|
||||||
import libburrow
|
|
||||||
import NetworkExtension
|
import NetworkExtension
|
||||||
import OSLog
|
|
||||||
|
|
||||||
class PacketTunnelProvider: NEPacketTunnelProvider {
|
class PacketTunnelProvider: NEPacketTunnelProvider {
|
||||||
let logger = Logger(subsystem: "com.hackclub.burrow", category: "General")
|
|
||||||
override func startTunnel(options: [String: NSObject]?, completionHandler: @escaping (Error?) -> Void) {
|
override func startTunnel(options: [String: NSObject]?, completionHandler: @escaping (Error?) -> Void) {
|
||||||
let fild = libburrow.retrieve()
|
|
||||||
if fild == -1 {
|
|
||||||
// Not sure if this is the right way to return an error
|
|
||||||
logger.error("Failed to retrieve file descriptor for burrow.")
|
|
||||||
let err = NSError(
|
|
||||||
domain: "com.hackclub.burrow",
|
|
||||||
code: 1_010,
|
|
||||||
userInfo: [NSLocalizedDescriptionKey: "Failed to find TunInterface"]
|
|
||||||
)
|
|
||||||
completionHandler(err)
|
|
||||||
}
|
|
||||||
logger.info("fd: \(fild)")
|
|
||||||
completionHandler(nil)
|
completionHandler(nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1 +1 @@
|
||||||
int retrieve();
|
|
||||||
|
|
|
||||||
607
Cargo.lock
generated
607
Cargo.lock
generated
File diff suppressed because it is too large
Load diff
|
|
@ -17,9 +17,9 @@ Apple/ # Xcode project for burrow on macOS and iOS
|
||||||
burrow/ # Higher-level API library for tun and tun-async
|
burrow/ # Higher-level API library for tun and tun-async
|
||||||
tun/ # Low-level interface to OS networking
|
tun/ # Low-level interface to OS networking
|
||||||
src/
|
src/
|
||||||
tokio/ # Async/Tokio code
|
|
||||||
unix/ # macOS and Linux code
|
unix/ # macOS and Linux code
|
||||||
windows/ # Windows networking code
|
windows/ # Windows networking code
|
||||||
|
tun-async/ # Async interface to tun
|
||||||
```
|
```
|
||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
|
|
||||||
|
|
@ -7,23 +7,14 @@ edition = "2021"
|
||||||
crate-type = ["lib", "staticlib"]
|
crate-type = ["lib", "staticlib"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
anyhow = "1.0"
|
tokio = { version = "1.21", features = ["rt", "macros"] }
|
||||||
tokio = { version = "1.21", features = ["rt", "macros", "sync", "io-util"] }
|
tun = { version = "0.1", path = "../tun" }
|
||||||
tun = { version = "0.1", path = "../tun", features = ["serde"] }
|
|
||||||
clap = { version = "4.3.2", features = ["derive"] }
|
clap = { version = "4.3.2", features = ["derive"] }
|
||||||
tracing = "0.1"
|
|
||||||
tracing-log = "0.1"
|
|
||||||
tracing-journald = "0.3"
|
|
||||||
tracing-oslog = {git = "https://github.com/Stormshield-robinc/tracing-oslog"}
|
|
||||||
tracing-subscriber = "0.3"
|
|
||||||
env_logger = "0.10"
|
env_logger = "0.10"
|
||||||
log = "0.4"
|
log = "0.4"
|
||||||
serde = { version = "1", features = ["derive"] }
|
|
||||||
serde_json = "1"
|
|
||||||
|
|
||||||
[target.'cfg(target_os = "linux")'.dependencies]
|
[target.'cfg(target_os = "linux")'.dependencies]
|
||||||
caps = "0.5.5"
|
caps = "0.5.5"
|
||||||
libsystemd = "0.6"
|
|
||||||
|
|
||||||
[target.'cfg(target_vendor = "apple")'.dependencies]
|
[target.'cfg(target_vendor = "apple")'.dependencies]
|
||||||
nix = { version = "0.26.2" }
|
nix = { version = "0.26.2" }
|
||||||
|
|
|
||||||
|
|
@ -1,66 +0,0 @@
|
||||||
[package]
|
|
||||||
name = "boringtun"
|
|
||||||
description = "an implementation of the WireGuard® protocol designed for portability and speed"
|
|
||||||
version = "0.6.0"
|
|
||||||
authors = [
|
|
||||||
"Noah Kennedy <nkennedy@cloudflare.com>",
|
|
||||||
"Andy Grover <agrover@cloudflare.com>",
|
|
||||||
"Jeff Hiner <jhiner@cloudflare.com>",
|
|
||||||
]
|
|
||||||
license = "BSD-3-Clause"
|
|
||||||
repository = "https://github.com/cloudflare/boringtun"
|
|
||||||
documentation = "https://docs.rs/boringtun/0.5.2/boringtun/"
|
|
||||||
edition = "2018"
|
|
||||||
|
|
||||||
[features]
|
|
||||||
default = []
|
|
||||||
device = ["socket2", "thiserror"]
|
|
||||||
jni-bindings = ["ffi-bindings", "jni"]
|
|
||||||
ffi-bindings = ["tracing-subscriber"]
|
|
||||||
# mocks std::time::Instant with mock_instant
|
|
||||||
mock-instant = ["mock_instant"]
|
|
||||||
|
|
||||||
[workspace]
|
|
||||||
|
|
||||||
[dependencies]
|
|
||||||
base64 = "0.13"
|
|
||||||
hex = "0.4"
|
|
||||||
untrusted = "0.9.0"
|
|
||||||
libc = "0.2"
|
|
||||||
parking_lot = "0.12"
|
|
||||||
tracing = "0.1.29"
|
|
||||||
tracing-subscriber = { version = "0.3", features = ["fmt"], optional = true }
|
|
||||||
ip_network = "0.4.1"
|
|
||||||
ip_network_table = "0.2.0"
|
|
||||||
ring = "0.16"
|
|
||||||
x25519-dalek = { version = "2.0.0", features = [
|
|
||||||
"reusable_secrets",
|
|
||||||
"static_secrets",
|
|
||||||
] }
|
|
||||||
rand_core = { version = "0.6.3", features = ["getrandom"] }
|
|
||||||
chacha20poly1305 = "0.10.0-pre.1"
|
|
||||||
aead = "0.5.0-pre.2"
|
|
||||||
blake2 = "0.10"
|
|
||||||
hmac = "0.12"
|
|
||||||
jni = { version = "0.19.0", optional = true }
|
|
||||||
mock_instant = { version = "0.2", optional = true }
|
|
||||||
socket2 = { version = "0.4.7", features = ["all"], optional = true }
|
|
||||||
thiserror = { version = "1", optional = true }
|
|
||||||
|
|
||||||
[target.'cfg(unix)'.dependencies]
|
|
||||||
nix = { version = "0.25", default-features = false, features = [
|
|
||||||
"time",
|
|
||||||
"user",
|
|
||||||
] }
|
|
||||||
|
|
||||||
[dev-dependencies]
|
|
||||||
etherparse = "0.12"
|
|
||||||
tracing-subscriber = "0.3"
|
|
||||||
criterion = { version = "0.3.5", features = ["html_reports"] }
|
|
||||||
|
|
||||||
[lib]
|
|
||||||
crate-type = ["staticlib", "cdylib", "rlib"]
|
|
||||||
|
|
||||||
[[bench]]
|
|
||||||
name = "crypto_benches"
|
|
||||||
harness = false
|
|
||||||
|
|
@ -1,902 +0,0 @@
|
||||||
// Copyright (c) 2019 Cloudflare, Inc. All rights reserved.
|
|
||||||
// SPDX-License-Identifier: BSD-3-Clause
|
|
||||||
|
|
||||||
pub mod allowed_ips;
|
|
||||||
pub mod api;
|
|
||||||
mod dev_lock;
|
|
||||||
pub mod drop_privileges;
|
|
||||||
#[cfg(test)]
|
|
||||||
mod integration_tests;
|
|
||||||
pub mod peer;
|
|
||||||
|
|
||||||
#[cfg(any(target_os = "macos", target_os = "ios"))]
|
|
||||||
#[path = "kqueue.rs"]
|
|
||||||
pub mod poll;
|
|
||||||
|
|
||||||
#[cfg(target_os = "linux")]
|
|
||||||
#[path = "epoll.rs"]
|
|
||||||
pub mod poll;
|
|
||||||
|
|
||||||
#[cfg(any(target_os = "macos", target_os = "ios"))]
|
|
||||||
#[path = "tun_darwin.rs"]
|
|
||||||
pub mod tun;
|
|
||||||
|
|
||||||
#[cfg(target_os = "linux")]
|
|
||||||
#[path = "tun_linux.rs"]
|
|
||||||
pub mod tun;
|
|
||||||
|
|
||||||
use std::collections::HashMap;
|
|
||||||
use std::io::{self, Write as _};
|
|
||||||
use std::mem::MaybeUninit;
|
|
||||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
|
|
||||||
use std::os::unix::io::AsRawFd;
|
|
||||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
|
||||||
use std::sync::Arc;
|
|
||||||
use std::thread;
|
|
||||||
use std::thread::JoinHandle;
|
|
||||||
|
|
||||||
use crate::noise::errors::WireGuardError;
|
|
||||||
use crate::noise::handshake::parse_handshake_anon;
|
|
||||||
use crate::noise::rate_limiter::RateLimiter;
|
|
||||||
use crate::noise::{Packet, Tunn, TunnResult};
|
|
||||||
use crate::x25519;
|
|
||||||
use allowed_ips::AllowedIps;
|
|
||||||
use parking_lot::Mutex;
|
|
||||||
use peer::{AllowedIP, Peer};
|
|
||||||
use poll::{EventPoll, EventRef, WaitResult};
|
|
||||||
use rand_core::{OsRng, RngCore};
|
|
||||||
use socket2::{Domain, Protocol, Type};
|
|
||||||
use tun::TunSocket;
|
|
||||||
|
|
||||||
use dev_lock::{Lock, LockReadGuard};
|
|
||||||
|
|
||||||
const HANDSHAKE_RATE_LIMIT: u64 = 100; // The number of handshakes per second we can tolerate before using cookies
|
|
||||||
|
|
||||||
const MAX_UDP_SIZE: usize = (1 << 16) - 1;
|
|
||||||
const MAX_ITR: usize = 100; // Number of packets to handle per handler call
|
|
||||||
|
|
||||||
#[derive(Debug, thiserror::Error)]
|
|
||||||
pub enum Error {
|
|
||||||
#[error("i/o error: {0}")]
|
|
||||||
IoError(#[from] io::Error),
|
|
||||||
#[error("{0}")]
|
|
||||||
Socket(io::Error),
|
|
||||||
#[error("{0}")]
|
|
||||||
Bind(String),
|
|
||||||
#[error("{0}")]
|
|
||||||
FCntl(io::Error),
|
|
||||||
#[error("{0}")]
|
|
||||||
EventQueue(io::Error),
|
|
||||||
#[error("{0}")]
|
|
||||||
IOCtl(io::Error),
|
|
||||||
#[error("{0}")]
|
|
||||||
Connect(String),
|
|
||||||
#[error("{0}")]
|
|
||||||
SetSockOpt(String),
|
|
||||||
#[error("Invalid tunnel name")]
|
|
||||||
InvalidTunnelName,
|
|
||||||
#[cfg(any(target_os = "macos", target_os = "ios"))]
|
|
||||||
#[error("{0}")]
|
|
||||||
GetSockOpt(io::Error),
|
|
||||||
#[error("{0}")]
|
|
||||||
GetSockName(String),
|
|
||||||
#[cfg(target_os = "linux")]
|
|
||||||
#[error("{0}")]
|
|
||||||
Timer(io::Error),
|
|
||||||
#[error("iface read: {0}")]
|
|
||||||
IfaceRead(io::Error),
|
|
||||||
#[error("{0}")]
|
|
||||||
DropPrivileges(String),
|
|
||||||
#[error("API socket error: {0}")]
|
|
||||||
ApiSocket(io::Error),
|
|
||||||
}
|
|
||||||
|
|
||||||
// What the event loop should do after a handler returns
|
|
||||||
enum Action {
|
|
||||||
Continue, // Continue the loop
|
|
||||||
Yield, // Yield the read lock and acquire it again
|
|
||||||
Exit, // Stop the loop
|
|
||||||
}
|
|
||||||
|
|
||||||
// Event handler function
|
|
||||||
type Handler = Box<dyn Fn(&mut LockReadGuard<Device>, &mut ThreadData) -> Action + Send + Sync>;
|
|
||||||
|
|
||||||
pub struct DeviceHandle {
|
|
||||||
device: Arc<Lock<Device>>, // The interface this handle owns
|
|
||||||
threads: Vec<JoinHandle<()>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy)]
|
|
||||||
pub struct DeviceConfig {
|
|
||||||
pub n_threads: usize,
|
|
||||||
pub use_connected_socket: bool,
|
|
||||||
#[cfg(target_os = "linux")]
|
|
||||||
pub use_multi_queue: bool,
|
|
||||||
#[cfg(target_os = "linux")]
|
|
||||||
pub uapi_fd: i32,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for DeviceConfig {
|
|
||||||
fn default() -> Self {
|
|
||||||
DeviceConfig {
|
|
||||||
n_threads: 4,
|
|
||||||
use_connected_socket: true,
|
|
||||||
#[cfg(target_os = "linux")]
|
|
||||||
use_multi_queue: true,
|
|
||||||
#[cfg(target_os = "linux")]
|
|
||||||
uapi_fd: -1,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct Device {
|
|
||||||
key_pair: Option<(x25519::StaticSecret, x25519::PublicKey)>,
|
|
||||||
queue: Arc<EventPoll<Handler>>,
|
|
||||||
|
|
||||||
listen_port: u16,
|
|
||||||
fwmark: Option<u32>,
|
|
||||||
|
|
||||||
iface: Arc<TunSocket>,
|
|
||||||
udp4: Option<socket2::Socket>,
|
|
||||||
udp6: Option<socket2::Socket>,
|
|
||||||
|
|
||||||
yield_notice: Option<EventRef>,
|
|
||||||
exit_notice: Option<EventRef>,
|
|
||||||
|
|
||||||
peers: HashMap<x25519::PublicKey, Arc<Mutex<Peer>>>,
|
|
||||||
peers_by_ip: AllowedIps<Arc<Mutex<Peer>>>,
|
|
||||||
peers_by_idx: HashMap<u32, Arc<Mutex<Peer>>>,
|
|
||||||
next_index: IndexLfsr,
|
|
||||||
|
|
||||||
config: DeviceConfig,
|
|
||||||
|
|
||||||
cleanup_paths: Vec<String>,
|
|
||||||
|
|
||||||
mtu: AtomicUsize,
|
|
||||||
|
|
||||||
rate_limiter: Option<Arc<RateLimiter>>,
|
|
||||||
|
|
||||||
#[cfg(target_os = "linux")]
|
|
||||||
uapi_fd: i32,
|
|
||||||
}
|
|
||||||
|
|
||||||
struct ThreadData {
|
|
||||||
iface: Arc<TunSocket>,
|
|
||||||
src_buf: [u8; MAX_UDP_SIZE],
|
|
||||||
dst_buf: [u8; MAX_UDP_SIZE],
|
|
||||||
}
|
|
||||||
|
|
||||||
impl DeviceHandle {
|
|
||||||
pub fn new(name: &str, config: DeviceConfig) -> Result<DeviceHandle, Error> {
|
|
||||||
let n_threads = config.n_threads;
|
|
||||||
let mut wg_interface = Device::new(name, config)?;
|
|
||||||
wg_interface.open_listen_socket(0)?; // Start listening on a random port
|
|
||||||
|
|
||||||
let interface_lock = Arc::new(Lock::new(wg_interface));
|
|
||||||
|
|
||||||
let mut threads = vec![];
|
|
||||||
|
|
||||||
for i in 0..n_threads {
|
|
||||||
threads.push({
|
|
||||||
let dev = Arc::clone(&interface_lock);
|
|
||||||
thread::spawn(move || DeviceHandle::event_loop(i, &dev))
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(DeviceHandle {
|
|
||||||
device: interface_lock,
|
|
||||||
threads,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn wait(&mut self) {
|
|
||||||
while let Some(thread) = self.threads.pop() {
|
|
||||||
thread.join().unwrap();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn clean(&mut self) {
|
|
||||||
for path in &self.device.read().cleanup_paths {
|
|
||||||
// attempt to remove any file we created in the work dir
|
|
||||||
let _ = std::fs::remove_file(path);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn event_loop(_i: usize, device: &Lock<Device>) {
|
|
||||||
#[cfg(target_os = "linux")]
|
|
||||||
let mut thread_local = ThreadData {
|
|
||||||
src_buf: [0u8; MAX_UDP_SIZE],
|
|
||||||
dst_buf: [0u8; MAX_UDP_SIZE],
|
|
||||||
iface: if _i == 0 || !device.read().config.use_multi_queue {
|
|
||||||
// For the first thread use the original iface
|
|
||||||
Arc::clone(&device.read().iface)
|
|
||||||
} else {
|
|
||||||
// For for the rest create a new iface queue
|
|
||||||
let iface_local = Arc::new(
|
|
||||||
TunSocket::new(&device.read().iface.name().unwrap())
|
|
||||||
.unwrap()
|
|
||||||
.set_non_blocking()
|
|
||||||
.unwrap(),
|
|
||||||
);
|
|
||||||
|
|
||||||
device
|
|
||||||
.read()
|
|
||||||
.register_iface_handler(Arc::clone(&iface_local))
|
|
||||||
.ok();
|
|
||||||
|
|
||||||
iface_local
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
#[cfg(not(target_os = "linux"))]
|
|
||||||
let mut thread_local = ThreadData {
|
|
||||||
src_buf: [0u8; MAX_UDP_SIZE],
|
|
||||||
dst_buf: [0u8; MAX_UDP_SIZE],
|
|
||||||
iface: Arc::clone(&device.read().iface),
|
|
||||||
};
|
|
||||||
|
|
||||||
#[cfg(not(target_os = "linux"))]
|
|
||||||
let uapi_fd = -1;
|
|
||||||
#[cfg(target_os = "linux")]
|
|
||||||
let uapi_fd = device.read().uapi_fd;
|
|
||||||
|
|
||||||
loop {
|
|
||||||
// The event loop keeps a read lock on the device, because we assume write access is rarely needed
|
|
||||||
let mut device_lock = device.read();
|
|
||||||
let queue = Arc::clone(&device_lock.queue);
|
|
||||||
|
|
||||||
loop {
|
|
||||||
match queue.wait() {
|
|
||||||
WaitResult::Ok(handler) => {
|
|
||||||
let action = (*handler)(&mut device_lock, &mut thread_local);
|
|
||||||
match action {
|
|
||||||
Action::Continue => {}
|
|
||||||
Action::Yield => break,
|
|
||||||
Action::Exit => {
|
|
||||||
device_lock.trigger_exit();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
WaitResult::EoF(handler) => {
|
|
||||||
if uapi_fd >= 0 && uapi_fd == handler.fd() {
|
|
||||||
device_lock.trigger_exit();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
handler.cancel();
|
|
||||||
}
|
|
||||||
WaitResult::Error(e) => tracing::error!(message = "Poll error", error = ?e),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Drop for DeviceHandle {
|
|
||||||
fn drop(&mut self) {
|
|
||||||
self.device.read().trigger_exit();
|
|
||||||
self.clean();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Device {
|
|
||||||
fn next_index(&mut self) -> u32 {
|
|
||||||
self.next_index.next()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn remove_peer(&mut self, pub_key: &x25519::PublicKey) {
|
|
||||||
if let Some(peer) = self.peers.remove(pub_key) {
|
|
||||||
// Found a peer to remove, now purge all references to it:
|
|
||||||
{
|
|
||||||
let p = peer.lock();
|
|
||||||
p.shutdown_endpoint(); // close open udp socket and free the closure
|
|
||||||
self.peers_by_idx.remove(&p.index());
|
|
||||||
}
|
|
||||||
self.peers_by_ip
|
|
||||||
.remove(&|p: &Arc<Mutex<Peer>>| Arc::ptr_eq(&peer, p));
|
|
||||||
|
|
||||||
tracing::info!("Peer removed");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
|
||||||
fn update_peer(
|
|
||||||
&mut self,
|
|
||||||
pub_key: x25519::PublicKey,
|
|
||||||
remove: bool,
|
|
||||||
_replace_ips: bool,
|
|
||||||
endpoint: Option<SocketAddr>,
|
|
||||||
allowed_ips: &[AllowedIP],
|
|
||||||
keepalive: Option<u16>,
|
|
||||||
preshared_key: Option<[u8; 32]>,
|
|
||||||
) {
|
|
||||||
if remove {
|
|
||||||
// Completely remove a peer
|
|
||||||
return self.remove_peer(&pub_key);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update an existing peer
|
|
||||||
if self.peers.get(&pub_key).is_some() {
|
|
||||||
// We already have a peer, we need to merge the existing config into the newly created one
|
|
||||||
panic!("Modifying existing peers is not yet supported. Remove and add again instead.");
|
|
||||||
}
|
|
||||||
|
|
||||||
let next_index = self.next_index();
|
|
||||||
let device_key_pair = self
|
|
||||||
.key_pair
|
|
||||||
.as_ref()
|
|
||||||
.expect("Private key must be set first");
|
|
||||||
|
|
||||||
let tunn = Tunn::new(
|
|
||||||
device_key_pair.0.clone(),
|
|
||||||
pub_key,
|
|
||||||
preshared_key,
|
|
||||||
keepalive,
|
|
||||||
next_index,
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let peer = Peer::new(tunn, next_index, endpoint, allowed_ips, preshared_key);
|
|
||||||
|
|
||||||
let peer = Arc::new(Mutex::new(peer));
|
|
||||||
self.peers.insert(pub_key, Arc::clone(&peer));
|
|
||||||
self.peers_by_idx.insert(next_index, Arc::clone(&peer));
|
|
||||||
|
|
||||||
for AllowedIP { addr, cidr } in allowed_ips {
|
|
||||||
self.peers_by_ip
|
|
||||||
.insert(*addr, *cidr as _, Arc::clone(&peer));
|
|
||||||
}
|
|
||||||
|
|
||||||
tracing::info!("Peer added");
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn new(name: &str, config: DeviceConfig) -> Result<Device, Error> {
|
|
||||||
let poll = EventPoll::<Handler>::new()?;
|
|
||||||
|
|
||||||
// Create a tunnel device
|
|
||||||
let iface = Arc::new(TunSocket::new(name)?.set_non_blocking()?);
|
|
||||||
let mtu = iface.mtu()?;
|
|
||||||
|
|
||||||
#[cfg(not(target_os = "linux"))]
|
|
||||||
let uapi_fd = -1;
|
|
||||||
#[cfg(target_os = "linux")]
|
|
||||||
let uapi_fd = config.uapi_fd;
|
|
||||||
|
|
||||||
let mut device = Device {
|
|
||||||
queue: Arc::new(poll),
|
|
||||||
iface,
|
|
||||||
config,
|
|
||||||
exit_notice: Default::default(),
|
|
||||||
yield_notice: Default::default(),
|
|
||||||
fwmark: Default::default(),
|
|
||||||
key_pair: Default::default(),
|
|
||||||
listen_port: Default::default(),
|
|
||||||
next_index: Default::default(),
|
|
||||||
peers: Default::default(),
|
|
||||||
peers_by_idx: Default::default(),
|
|
||||||
peers_by_ip: AllowedIps::new(),
|
|
||||||
udp4: Default::default(),
|
|
||||||
udp6: Default::default(),
|
|
||||||
cleanup_paths: Default::default(),
|
|
||||||
mtu: AtomicUsize::new(mtu),
|
|
||||||
rate_limiter: None,
|
|
||||||
#[cfg(target_os = "linux")]
|
|
||||||
uapi_fd,
|
|
||||||
};
|
|
||||||
|
|
||||||
if uapi_fd >= 0 {
|
|
||||||
device.register_api_fd(uapi_fd)?;
|
|
||||||
} else {
|
|
||||||
device.register_api_handler()?;
|
|
||||||
}
|
|
||||||
device.register_iface_handler(Arc::clone(&device.iface))?;
|
|
||||||
device.register_notifiers()?;
|
|
||||||
device.register_timers()?;
|
|
||||||
|
|
||||||
#[cfg(target_os = "macos")]
|
|
||||||
{
|
|
||||||
// Only for macOS write the actual socket name into WG_TUN_NAME_FILE
|
|
||||||
if let Ok(name_file) = std::env::var("WG_TUN_NAME_FILE") {
|
|
||||||
if name == "utun" {
|
|
||||||
std::fs::write(&name_file, device.iface.name().unwrap().as_bytes()).unwrap();
|
|
||||||
device.cleanup_paths.push(name_file);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(device)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn open_listen_socket(&mut self, mut port: u16) -> Result<(), Error> {
|
|
||||||
// Binds the network facing interfaces
|
|
||||||
// First close any existing open socket, and remove them from the event loop
|
|
||||||
if let Some(s) = self.udp4.take() {
|
|
||||||
unsafe {
|
|
||||||
// This is safe because the event loop is not running yet
|
|
||||||
self.queue.clear_event_by_fd(s.as_raw_fd())
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
if let Some(s) = self.udp6.take() {
|
|
||||||
unsafe { self.queue.clear_event_by_fd(s.as_raw_fd()) };
|
|
||||||
}
|
|
||||||
|
|
||||||
for peer in self.peers.values() {
|
|
||||||
peer.lock().shutdown_endpoint();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Then open new sockets and bind to the port
|
|
||||||
let udp_sock4 = socket2::Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))?;
|
|
||||||
udp_sock4.set_reuse_address(true)?;
|
|
||||||
udp_sock4.bind(&SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, port).into())?;
|
|
||||||
udp_sock4.set_nonblocking(true)?;
|
|
||||||
|
|
||||||
if port == 0 {
|
|
||||||
// Random port was assigned
|
|
||||||
port = udp_sock4.local_addr()?.as_socket().unwrap().port();
|
|
||||||
}
|
|
||||||
|
|
||||||
let udp_sock6 = socket2::Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP))?;
|
|
||||||
udp_sock6.set_reuse_address(true)?;
|
|
||||||
udp_sock6.bind(&SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, port, 0, 0).into())?;
|
|
||||||
udp_sock6.set_nonblocking(true)?;
|
|
||||||
|
|
||||||
self.register_udp_handler(udp_sock4.try_clone().unwrap())?;
|
|
||||||
self.register_udp_handler(udp_sock6.try_clone().unwrap())?;
|
|
||||||
self.udp4 = Some(udp_sock4);
|
|
||||||
self.udp6 = Some(udp_sock6);
|
|
||||||
|
|
||||||
self.listen_port = port;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn set_key(&mut self, private_key: x25519::StaticSecret) {
|
|
||||||
let mut bad_peers = vec![];
|
|
||||||
|
|
||||||
let public_key = x25519::PublicKey::from(&private_key);
|
|
||||||
let key_pair = Some((private_key.clone(), public_key));
|
|
||||||
|
|
||||||
// x25519 (rightly) doesn't let us expose secret keys for comparison.
|
|
||||||
// If the public keys are the same, then the private keys are the same.
|
|
||||||
if Some(&public_key) == self.key_pair.as_ref().map(|p| &p.1) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
let rate_limiter = Arc::new(RateLimiter::new(&public_key, HANDSHAKE_RATE_LIMIT));
|
|
||||||
|
|
||||||
for peer in self.peers.values_mut() {
|
|
||||||
let mut peer_mut = peer.lock();
|
|
||||||
|
|
||||||
if peer_mut
|
|
||||||
.tunnel
|
|
||||||
.set_static_private(
|
|
||||||
private_key.clone(),
|
|
||||||
public_key,
|
|
||||||
Some(Arc::clone(&rate_limiter)),
|
|
||||||
)
|
|
||||||
.is_err()
|
|
||||||
{
|
|
||||||
// In case we encounter an error, we will remove that peer
|
|
||||||
// An error will be a result of bad public key/secret key combination
|
|
||||||
bad_peers.push(Arc::clone(peer));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
self.key_pair = key_pair;
|
|
||||||
self.rate_limiter = Some(rate_limiter);
|
|
||||||
|
|
||||||
// Remove all the bad peers
|
|
||||||
for _ in bad_peers {
|
|
||||||
unimplemented!();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
|
|
||||||
fn set_fwmark(&mut self, mark: u32) -> Result<(), Error> {
|
|
||||||
self.fwmark = Some(mark);
|
|
||||||
|
|
||||||
// First set fwmark on listeners
|
|
||||||
if let Some(ref sock) = self.udp4 {
|
|
||||||
sock.set_mark(mark)?;
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(ref sock) = self.udp6 {
|
|
||||||
sock.set_mark(mark)?;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Then on all currently connected sockets
|
|
||||||
for peer in self.peers.values() {
|
|
||||||
if let Some(ref sock) = peer.lock().endpoint().conn {
|
|
||||||
sock.set_mark(mark)?
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn clear_peers(&mut self) {
|
|
||||||
self.peers.clear();
|
|
||||||
self.peers_by_idx.clear();
|
|
||||||
self.peers_by_ip.clear();
|
|
||||||
}
|
|
||||||
|
|
||||||
fn register_notifiers(&mut self) -> Result<(), Error> {
|
|
||||||
let yield_ev = self
|
|
||||||
.queue
|
|
||||||
// The notification event handler simply returns Action::Yield
|
|
||||||
.new_notifier(Box::new(|_, _| Action::Yield))?;
|
|
||||||
self.yield_notice = Some(yield_ev);
|
|
||||||
|
|
||||||
let exit_ev = self
|
|
||||||
.queue
|
|
||||||
// The exit event handler simply returns Action::Exit
|
|
||||||
.new_notifier(Box::new(|_, _| Action::Exit))?;
|
|
||||||
self.exit_notice = Some(exit_ev);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn register_timers(&self) -> Result<(), Error> {
|
|
||||||
self.queue.new_periodic_event(
|
|
||||||
// Reset the rate limiter every second give or take
|
|
||||||
Box::new(|d, _| {
|
|
||||||
if let Some(r) = d.rate_limiter.as_ref() {
|
|
||||||
r.reset_count()
|
|
||||||
}
|
|
||||||
Action::Continue
|
|
||||||
}),
|
|
||||||
std::time::Duration::from_secs(1),
|
|
||||||
)?;
|
|
||||||
|
|
||||||
self.queue.new_periodic_event(
|
|
||||||
// Execute the timed function of every peer in the list
|
|
||||||
Box::new(|d, t| {
|
|
||||||
let peer_map = &d.peers;
|
|
||||||
|
|
||||||
let (udp4, udp6) = match (d.udp4.as_ref(), d.udp6.as_ref()) {
|
|
||||||
(Some(udp4), Some(udp6)) => (udp4, udp6),
|
|
||||||
_ => return Action::Continue,
|
|
||||||
};
|
|
||||||
|
|
||||||
// Go over each peer and invoke the timer function
|
|
||||||
for peer in peer_map.values() {
|
|
||||||
let mut p = peer.lock();
|
|
||||||
let endpoint_addr = match p.endpoint().addr {
|
|
||||||
Some(addr) => addr,
|
|
||||||
None => continue,
|
|
||||||
};
|
|
||||||
|
|
||||||
match p.update_timers(&mut t.dst_buf[..]) {
|
|
||||||
TunnResult::Done => {}
|
|
||||||
TunnResult::Err(WireGuardError::ConnectionExpired) => {
|
|
||||||
p.shutdown_endpoint(); // close open udp socket
|
|
||||||
}
|
|
||||||
TunnResult::Err(e) => tracing::error!(message = "Timer error", error = ?e),
|
|
||||||
TunnResult::WriteToNetwork(packet) => {
|
|
||||||
match endpoint_addr {
|
|
||||||
SocketAddr::V4(_) => {
|
|
||||||
udp4.send_to(packet, &endpoint_addr.into()).ok()
|
|
||||||
}
|
|
||||||
SocketAddr::V6(_) => {
|
|
||||||
udp6.send_to(packet, &endpoint_addr.into()).ok()
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}
|
|
||||||
_ => panic!("Unexpected result from update_timers"),
|
|
||||||
};
|
|
||||||
}
|
|
||||||
Action::Continue
|
|
||||||
}),
|
|
||||||
std::time::Duration::from_millis(250),
|
|
||||||
)?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn trigger_yield(&self) {
|
|
||||||
self.queue
|
|
||||||
.trigger_notification(self.yield_notice.as_ref().unwrap())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn trigger_exit(&self) {
|
|
||||||
self.queue
|
|
||||||
.trigger_notification(self.exit_notice.as_ref().unwrap())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn cancel_yield(&self) {
|
|
||||||
self.queue
|
|
||||||
.stop_notification(self.yield_notice.as_ref().unwrap())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn register_udp_handler(&self, udp: socket2::Socket) -> Result<(), Error> {
|
|
||||||
self.queue.new_event(
|
|
||||||
udp.as_raw_fd(),
|
|
||||||
Box::new(move |d, t| {
|
|
||||||
// Handler that handles anonymous packets over UDP
|
|
||||||
let mut iter = MAX_ITR;
|
|
||||||
let (private_key, public_key) = d.key_pair.as_ref().expect("Key not set");
|
|
||||||
|
|
||||||
let rate_limiter = d.rate_limiter.as_ref().unwrap();
|
|
||||||
|
|
||||||
// Loop while we have packets on the anonymous connection
|
|
||||||
|
|
||||||
// Safety: the `recv_from` implementation promises not to write uninitialised
|
|
||||||
// bytes to the buffer, so this casting is safe.
|
|
||||||
let src_buf =
|
|
||||||
unsafe { &mut *(&mut t.src_buf[..] as *mut [u8] as *mut [MaybeUninit<u8>]) };
|
|
||||||
while let Ok((packet_len, addr)) = udp.recv_from(src_buf) {
|
|
||||||
let packet = &t.src_buf[..packet_len];
|
|
||||||
// The rate limiter initially checks mac1 and mac2, and optionally asks to send a cookie
|
|
||||||
let parsed_packet = match rate_limiter.verify_packet(
|
|
||||||
Some(addr.as_socket().unwrap().ip()),
|
|
||||||
packet,
|
|
||||||
&mut t.dst_buf,
|
|
||||||
) {
|
|
||||||
Ok(packet) => packet,
|
|
||||||
Err(TunnResult::WriteToNetwork(cookie)) => {
|
|
||||||
let _: Result<_, _> = udp.send_to(cookie, &addr);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
Err(_) => continue,
|
|
||||||
};
|
|
||||||
|
|
||||||
let peer = match &parsed_packet {
|
|
||||||
Packet::HandshakeInit(p) => {
|
|
||||||
parse_handshake_anon(private_key, public_key, p)
|
|
||||||
.ok()
|
|
||||||
.and_then(|hh| {
|
|
||||||
d.peers.get(&x25519::PublicKey::from(hh.peer_static_public))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
Packet::HandshakeResponse(p) => d.peers_by_idx.get(&(p.receiver_idx >> 8)),
|
|
||||||
Packet::PacketCookieReply(p) => d.peers_by_idx.get(&(p.receiver_idx >> 8)),
|
|
||||||
Packet::PacketData(p) => d.peers_by_idx.get(&(p.receiver_idx >> 8)),
|
|
||||||
};
|
|
||||||
|
|
||||||
let peer = match peer {
|
|
||||||
None => continue,
|
|
||||||
Some(peer) => peer,
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut p = peer.lock();
|
|
||||||
|
|
||||||
// We found a peer, use it to decapsulate the message+
|
|
||||||
let mut flush = false; // Are there packets to send from the queue?
|
|
||||||
match p
|
|
||||||
.tunnel
|
|
||||||
.handle_verified_packet(parsed_packet, &mut t.dst_buf[..])
|
|
||||||
{
|
|
||||||
TunnResult::Done => {}
|
|
||||||
TunnResult::Err(_) => continue,
|
|
||||||
TunnResult::WriteToNetwork(packet) => {
|
|
||||||
flush = true;
|
|
||||||
let _: Result<_, _> = udp.send_to(packet, &addr);
|
|
||||||
}
|
|
||||||
TunnResult::WriteToTunnelV4(packet, addr) => {
|
|
||||||
if p.is_allowed_ip(addr) {
|
|
||||||
t.iface.write4(packet);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
TunnResult::WriteToTunnelV6(packet, addr) => {
|
|
||||||
if p.is_allowed_ip(addr) {
|
|
||||||
t.iface.write6(packet);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
if flush {
|
|
||||||
// Flush pending queue
|
|
||||||
while let TunnResult::WriteToNetwork(packet) =
|
|
||||||
p.tunnel.decapsulate(None, &[], &mut t.dst_buf[..])
|
|
||||||
{
|
|
||||||
let _: Result<_, _> = udp.send_to(packet, &addr);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// This packet was OK, that means we want to create a connected socket for this peer
|
|
||||||
let addr = addr.as_socket().unwrap();
|
|
||||||
let ip_addr = addr.ip();
|
|
||||||
p.set_endpoint(addr);
|
|
||||||
if d.config.use_connected_socket {
|
|
||||||
if let Ok(sock) = p.connect_endpoint(d.listen_port, d.fwmark) {
|
|
||||||
d.register_conn_handler(Arc::clone(peer), sock, ip_addr)
|
|
||||||
.unwrap();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
iter -= 1;
|
|
||||||
if iter == 0 {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Action::Continue
|
|
||||||
}),
|
|
||||||
)?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn register_conn_handler(
|
|
||||||
&self,
|
|
||||||
peer: Arc<Mutex<Peer>>,
|
|
||||||
udp: socket2::Socket,
|
|
||||||
peer_addr: IpAddr,
|
|
||||||
) -> Result<(), Error> {
|
|
||||||
self.queue.new_event(
|
|
||||||
udp.as_raw_fd(),
|
|
||||||
Box::new(move |_, t| {
|
|
||||||
// The conn_handler handles packet received from a connected UDP socket, associated
|
|
||||||
// with a known peer, this saves us the hustle of finding the right peer. If another
|
|
||||||
// peer gets the same ip, it will be ignored until the socket does not expire.
|
|
||||||
let iface = &t.iface;
|
|
||||||
let mut iter = MAX_ITR;
|
|
||||||
|
|
||||||
// Safety: the `recv_from` implementation promises not to write uninitialised
|
|
||||||
// bytes to the buffer, so this casting is safe.
|
|
||||||
let src_buf =
|
|
||||||
unsafe { &mut *(&mut t.src_buf[..] as *mut [u8] as *mut [MaybeUninit<u8>]) };
|
|
||||||
|
|
||||||
while let Ok(read_bytes) = udp.recv(src_buf) {
|
|
||||||
let mut flush = false;
|
|
||||||
let mut p = peer.lock();
|
|
||||||
match p.tunnel.decapsulate(
|
|
||||||
Some(peer_addr),
|
|
||||||
&t.src_buf[..read_bytes],
|
|
||||||
&mut t.dst_buf[..],
|
|
||||||
) {
|
|
||||||
TunnResult::Done => {}
|
|
||||||
TunnResult::Err(e) => eprintln!("Decapsulate error {:?}", e),
|
|
||||||
TunnResult::WriteToNetwork(packet) => {
|
|
||||||
flush = true;
|
|
||||||
let _: Result<_, _> = udp.send(packet);
|
|
||||||
}
|
|
||||||
TunnResult::WriteToTunnelV4(packet, addr) => {
|
|
||||||
if p.is_allowed_ip(addr) {
|
|
||||||
iface.write4(packet);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
TunnResult::WriteToTunnelV6(packet, addr) => {
|
|
||||||
if p.is_allowed_ip(addr) {
|
|
||||||
iface.write6(packet);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
if flush {
|
|
||||||
// Flush pending queue
|
|
||||||
while let TunnResult::WriteToNetwork(packet) =
|
|
||||||
p.tunnel.decapsulate(None, &[], &mut t.dst_buf[..])
|
|
||||||
{
|
|
||||||
let _: Result<_, _> = udp.send(packet);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
iter -= 1;
|
|
||||||
if iter == 0 {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Action::Continue
|
|
||||||
}),
|
|
||||||
)?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn register_iface_handler(&self, iface: Arc<TunSocket>) -> Result<(), Error> {
|
|
||||||
self.queue.new_event(
|
|
||||||
iface.as_raw_fd(),
|
|
||||||
Box::new(move |d, t| {
|
|
||||||
// The iface_handler handles packets received from the WireGuard virtual network
|
|
||||||
// interface. The flow is as follows:
|
|
||||||
// * Read a packet
|
|
||||||
// * Determine peer based on packet destination ip
|
|
||||||
// * Encapsulate the packet for the given peer
|
|
||||||
// * Send encapsulated packet to the peer's endpoint
|
|
||||||
let mtu = d.mtu.load(Ordering::Relaxed);
|
|
||||||
|
|
||||||
let udp4 = d.udp4.as_ref().expect("Not connected");
|
|
||||||
let udp6 = d.udp6.as_ref().expect("Not connected");
|
|
||||||
|
|
||||||
let peers = &d.peers_by_ip;
|
|
||||||
for _ in 0..MAX_ITR {
|
|
||||||
let src = match iface.read(&mut t.src_buf[..mtu]) {
|
|
||||||
Ok(src) => src,
|
|
||||||
Err(Error::IfaceRead(e)) => {
|
|
||||||
let ek = e.kind();
|
|
||||||
if ek == io::ErrorKind::Interrupted || ek == io::ErrorKind::WouldBlock {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
eprintln!("Fatal read error on tun interface: {:?}", e);
|
|
||||||
return Action::Exit;
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
eprintln!("Unexpected error on tun interface: {:?}", e);
|
|
||||||
return Action::Exit;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let dst_addr = match Tunn::dst_address(src) {
|
|
||||||
Some(addr) => addr,
|
|
||||||
None => continue,
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut peer = match peers.find(dst_addr) {
|
|
||||||
Some(peer) => peer.lock(),
|
|
||||||
None => continue,
|
|
||||||
};
|
|
||||||
|
|
||||||
match peer.tunnel.encapsulate(src, &mut t.dst_buf[..]) {
|
|
||||||
TunnResult::Done => {}
|
|
||||||
TunnResult::Err(e) => {
|
|
||||||
tracing::error!(message = "Encapsulate error", error = ?e)
|
|
||||||
}
|
|
||||||
TunnResult::WriteToNetwork(packet) => {
|
|
||||||
let mut endpoint = peer.endpoint_mut();
|
|
||||||
if let Some(conn) = endpoint.conn.as_mut() {
|
|
||||||
// Prefer to send using the connected socket
|
|
||||||
let _: Result<_, _> = conn.write(packet);
|
|
||||||
} else if let Some(addr @ SocketAddr::V4(_)) = endpoint.addr {
|
|
||||||
let _: Result<_, _> = udp4.send_to(packet, &addr.into());
|
|
||||||
} else if let Some(addr @ SocketAddr::V6(_)) = endpoint.addr {
|
|
||||||
let _: Result<_, _> = udp6.send_to(packet, &addr.into());
|
|
||||||
} else {
|
|
||||||
tracing::error!("No endpoint");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ => panic!("Unexpected result from encapsulate"),
|
|
||||||
};
|
|
||||||
}
|
|
||||||
Action::Continue
|
|
||||||
}),
|
|
||||||
)?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A basic linear-feedback shift register implemented as xorshift, used to
|
|
||||||
/// distribute peer indexes across the 24-bit address space reserved for peer
|
|
||||||
/// identification.
|
|
||||||
/// The purpose is to obscure the total number of peers using the system and to
|
|
||||||
/// ensure it requires a non-trivial amount of processing power and/or samples
|
|
||||||
/// to guess other peers' indices. Anything more ambitious than this is wasted
|
|
||||||
/// with only 24 bits of space.
|
|
||||||
struct IndexLfsr {
|
|
||||||
initial: u32,
|
|
||||||
lfsr: u32,
|
|
||||||
mask: u32,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl IndexLfsr {
|
|
||||||
/// Generate a random 24-bit nonzero integer
|
|
||||||
fn random_index() -> u32 {
|
|
||||||
const LFSR_MAX: u32 = 0xffffff; // 24-bit seed
|
|
||||||
loop {
|
|
||||||
let i = OsRng.next_u32() & LFSR_MAX;
|
|
||||||
if i > 0 {
|
|
||||||
// LFSR seed must be non-zero
|
|
||||||
return i;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Generate the next value in the pseudorandom sequence
|
|
||||||
fn next(&mut self) -> u32 {
|
|
||||||
// 24-bit polynomial for randomness. This is arbitrarily chosen to
|
|
||||||
// inject bitflips into the value.
|
|
||||||
const LFSR_POLY: u32 = 0xd80000; // 24-bit polynomial
|
|
||||||
let value = self.lfsr - 1; // lfsr will never have value of 0
|
|
||||||
self.lfsr = (self.lfsr >> 1) ^ ((0u32.wrapping_sub(self.lfsr & 1u32)) & LFSR_POLY);
|
|
||||||
assert!(self.lfsr != self.initial, "Too many peers created");
|
|
||||||
value ^ self.mask
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for IndexLfsr {
|
|
||||||
fn default() -> Self {
|
|
||||||
let seed = Self::random_index();
|
|
||||||
IndexLfsr {
|
|
||||||
initial: seed,
|
|
||||||
lfsr: seed,
|
|
||||||
mask: Self::random_index(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,170 +0,0 @@
|
||||||
// Copyright (c) 2019 Cloudflare, Inc. All rights reserved.
|
|
||||||
// SPDX-License-Identifier: BSD-3-Clause
|
|
||||||
|
|
||||||
use parking_lot::RwLock;
|
|
||||||
use socket2::{Domain, Protocol, Type};
|
|
||||||
|
|
||||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, Shutdown, SocketAddr, SocketAddrV4, SocketAddrV6};
|
|
||||||
use std::str::FromStr;
|
|
||||||
|
|
||||||
use crate::device::{AllowedIps, Error};
|
|
||||||
use crate::noise::{Tunn, TunnResult};
|
|
||||||
|
|
||||||
#[derive(Default, Debug)]
|
|
||||||
pub struct Endpoint {
|
|
||||||
pub addr: Option<SocketAddr>,
|
|
||||||
pub conn: Option<socket2::Socket>,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct Peer {
|
|
||||||
/// The associated tunnel struct
|
|
||||||
pub(crate) tunnel: Tunn,
|
|
||||||
/// The index the tunnel uses
|
|
||||||
index: u32,
|
|
||||||
endpoint: RwLock<Endpoint>,
|
|
||||||
allowed_ips: AllowedIps<()>,
|
|
||||||
preshared_key: Option<[u8; 32]>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Debug)]
|
|
||||||
pub struct AllowedIP {
|
|
||||||
pub addr: IpAddr,
|
|
||||||
pub cidr: u8,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl FromStr for AllowedIP {
|
|
||||||
type Err = String;
|
|
||||||
|
|
||||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
|
||||||
let ip: Vec<&str> = s.split('/').collect();
|
|
||||||
if ip.len() != 2 {
|
|
||||||
return Err("Invalid IP format".to_owned());
|
|
||||||
}
|
|
||||||
|
|
||||||
let (addr, cidr) = (ip[0].parse::<IpAddr>(), ip[1].parse::<u8>());
|
|
||||||
match (addr, cidr) {
|
|
||||||
(Ok(addr @ IpAddr::V4(_)), Ok(cidr)) if cidr <= 32 => Ok(AllowedIP { addr, cidr }),
|
|
||||||
(Ok(addr @ IpAddr::V6(_)), Ok(cidr)) if cidr <= 128 => Ok(AllowedIP { addr, cidr }),
|
|
||||||
_ => Err("Invalid IP format".to_owned()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Peer {
|
|
||||||
pub fn new(
|
|
||||||
tunnel: Tunn,
|
|
||||||
index: u32,
|
|
||||||
endpoint: Option<SocketAddr>,
|
|
||||||
allowed_ips: &[AllowedIP],
|
|
||||||
preshared_key: Option<[u8; 32]>,
|
|
||||||
) -> Peer {
|
|
||||||
Peer {
|
|
||||||
tunnel,
|
|
||||||
index,
|
|
||||||
endpoint: RwLock::new(Endpoint {
|
|
||||||
addr: endpoint,
|
|
||||||
conn: None,
|
|
||||||
}),
|
|
||||||
allowed_ips: allowed_ips.iter().map(|ip| (ip, ())).collect(),
|
|
||||||
preshared_key,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn update_timers<'a>(&mut self, dst: &'a mut [u8]) -> TunnResult<'a> {
|
|
||||||
self.tunnel.update_timers(dst)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn endpoint(&self) -> parking_lot::RwLockReadGuard<'_, Endpoint> {
|
|
||||||
self.endpoint.read()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn endpoint_mut(&self) -> parking_lot::RwLockWriteGuard<'_, Endpoint> {
|
|
||||||
self.endpoint.write()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn shutdown_endpoint(&self) {
|
|
||||||
if let Some(conn) = self.endpoint.write().conn.take() {
|
|
||||||
tracing::info!("Disconnecting from endpoint");
|
|
||||||
conn.shutdown(Shutdown::Both).unwrap();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn set_endpoint(&self, addr: SocketAddr) {
|
|
||||||
let mut endpoint = self.endpoint.write();
|
|
||||||
if endpoint.addr != Some(addr) {
|
|
||||||
// We only need to update the endpoint if it differs from the current one
|
|
||||||
if let Some(conn) = endpoint.conn.take() {
|
|
||||||
conn.shutdown(Shutdown::Both).unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
endpoint.addr = Some(addr);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn connect_endpoint(
|
|
||||||
&self,
|
|
||||||
port: u16,
|
|
||||||
fwmark: Option<u32>,
|
|
||||||
) -> Result<socket2::Socket, Error> {
|
|
||||||
let mut endpoint = self.endpoint.write();
|
|
||||||
|
|
||||||
if endpoint.conn.is_some() {
|
|
||||||
return Err(Error::Connect("Connected".to_owned()));
|
|
||||||
}
|
|
||||||
|
|
||||||
let addr = endpoint
|
|
||||||
.addr
|
|
||||||
.expect("Attempt to connect to undefined endpoint");
|
|
||||||
|
|
||||||
let udp_conn =
|
|
||||||
socket2::Socket::new(Domain::for_address(addr), Type::STREAM, Some(Protocol::UDP))?;
|
|
||||||
udp_conn.set_reuse_address(true)?;
|
|
||||||
let bind_addr = if addr.is_ipv4() {
|
|
||||||
SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, port).into()
|
|
||||||
} else {
|
|
||||||
SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, port, 0, 0).into()
|
|
||||||
};
|
|
||||||
udp_conn.bind(&bind_addr)?;
|
|
||||||
udp_conn.connect(&addr.into())?;
|
|
||||||
udp_conn.set_nonblocking(true)?;
|
|
||||||
|
|
||||||
#[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
|
|
||||||
if let Some(fwmark) = fwmark {
|
|
||||||
udp_conn.set_mark(fwmark)?;
|
|
||||||
}
|
|
||||||
|
|
||||||
tracing::info!(
|
|
||||||
message="Connected endpoint",
|
|
||||||
port=port,
|
|
||||||
endpoint=?endpoint.addr.unwrap()
|
|
||||||
);
|
|
||||||
|
|
||||||
endpoint.conn = Some(udp_conn.try_clone().unwrap());
|
|
||||||
|
|
||||||
Ok(udp_conn)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn is_allowed_ip<I: Into<IpAddr>>(&self, addr: I) -> bool {
|
|
||||||
self.allowed_ips.find(addr.into()).is_some()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn allowed_ips(&self) -> impl Iterator<Item = (IpAddr, u8)> + '_ {
|
|
||||||
self.allowed_ips.iter().map(|(_, ip, cidr)| (ip, cidr))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn time_since_last_handshake(&self) -> Option<std::time::Duration> {
|
|
||||||
self.tunnel.time_since_last_handshake()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn persistent_keepalive(&self) -> Option<u16> {
|
|
||||||
self.tunnel.persistent_keepalive()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn preshared_key(&self) -> Option<&[u8; 32]> {
|
|
||||||
self.preshared_key.as_ref()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn index(&self) -> u32 {
|
|
||||||
self.index
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,27 +0,0 @@
|
||||||
// Copyright (c) 2019 Cloudflare, Inc. All rights reserved.
|
|
||||||
// SPDX-License-Identifier: BSD-3-Clause
|
|
||||||
|
|
||||||
//! Simple implementation of the client-side of the WireGuard protocol.
|
|
||||||
//!
|
|
||||||
//! <code>git clone https://github.com/cloudflare/boringtun.git</code>
|
|
||||||
|
|
||||||
#[cfg(feature = "device")]
|
|
||||||
pub mod device;
|
|
||||||
|
|
||||||
#[cfg(feature = "ffi-bindings")]
|
|
||||||
pub mod ffi;
|
|
||||||
#[cfg(feature = "jni-bindings")]
|
|
||||||
pub mod jni;
|
|
||||||
pub mod noise;
|
|
||||||
|
|
||||||
#[cfg(not(feature = "mock-instant"))]
|
|
||||||
pub(crate) mod sleepyinstant;
|
|
||||||
|
|
||||||
pub(crate) mod serialization;
|
|
||||||
|
|
||||||
/// Re-export of the x25519 types
|
|
||||||
pub mod x25519 {
|
|
||||||
pub use x25519_dalek::{
|
|
||||||
EphemeralSecret, PublicKey, ReusableSecret, SharedSecret, StaticSecret,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
@ -1,23 +0,0 @@
|
||||||
// Copyright (c) 2019 Cloudflare, Inc. All rights reserved.
|
|
||||||
// SPDX-License-Identifier: BSD-3-Clause
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub enum WireGuardError {
|
|
||||||
DestinationBufferTooSmall,
|
|
||||||
IncorrectPacketLength,
|
|
||||||
UnexpectedPacket,
|
|
||||||
WrongPacketType,
|
|
||||||
WrongIndex,
|
|
||||||
WrongKey,
|
|
||||||
InvalidTai64nTimestamp,
|
|
||||||
WrongTai64nTimestamp,
|
|
||||||
InvalidMac,
|
|
||||||
InvalidAeadTag,
|
|
||||||
InvalidCounter,
|
|
||||||
DuplicateCounter,
|
|
||||||
InvalidPacket,
|
|
||||||
NoCurrentSession,
|
|
||||||
LockFailed,
|
|
||||||
ConnectionExpired,
|
|
||||||
UnderLoad,
|
|
||||||
}
|
|
||||||
|
|
@ -1,941 +0,0 @@
|
||||||
// Copyright (c) 2019 Cloudflare, Inc. All rights reserved.
|
|
||||||
// SPDX-License-Identifier: BSD-3-Clause
|
|
||||||
|
|
||||||
use super::{HandshakeInit, HandshakeResponse, PacketCookieReply};
|
|
||||||
use crate::noise::errors::WireGuardError;
|
|
||||||
use crate::noise::session::Session;
|
|
||||||
#[cfg(not(feature = "mock-instant"))]
|
|
||||||
use crate::sleepyinstant::Instant;
|
|
||||||
use crate::x25519;
|
|
||||||
use aead::{Aead, Payload};
|
|
||||||
use blake2::digest::{FixedOutput, KeyInit};
|
|
||||||
use blake2::{Blake2s256, Blake2sMac, Digest};
|
|
||||||
use chacha20poly1305::XChaCha20Poly1305;
|
|
||||||
use rand_core::OsRng;
|
|
||||||
use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305};
|
|
||||||
use std::convert::TryInto;
|
|
||||||
use std::time::{Duration, SystemTime};
|
|
||||||
|
|
||||||
#[cfg(feature = "mock-instant")]
|
|
||||||
use mock_instant::Instant;
|
|
||||||
|
|
||||||
pub(crate) const LABEL_MAC1: &[u8; 8] = b"mac1----";
|
|
||||||
pub(crate) const LABEL_COOKIE: &[u8; 8] = b"cookie--";
|
|
||||||
const KEY_LEN: usize = 32;
|
|
||||||
const TIMESTAMP_LEN: usize = 12;
|
|
||||||
|
|
||||||
// initiator.chaining_key = HASH(CONSTRUCTION)
|
|
||||||
const INITIAL_CHAIN_KEY: [u8; KEY_LEN] = [
|
|
||||||
96, 226, 109, 174, 243, 39, 239, 192, 46, 195, 53, 226, 160, 37, 210, 208, 22, 235, 66, 6, 248,
|
|
||||||
114, 119, 245, 45, 56, 209, 152, 139, 120, 205, 54,
|
|
||||||
];
|
|
||||||
|
|
||||||
// initiator.chaining_hash = HASH(initiator.chaining_key || IDENTIFIER)
|
|
||||||
const INITIAL_CHAIN_HASH: [u8; KEY_LEN] = [
|
|
||||||
34, 17, 179, 97, 8, 26, 197, 102, 105, 18, 67, 219, 69, 138, 213, 50, 45, 156, 108, 102, 34,
|
|
||||||
147, 232, 183, 14, 225, 156, 101, 186, 7, 158, 243,
|
|
||||||
];
|
|
||||||
|
|
||||||
#[inline]
|
|
||||||
pub(crate) fn b2s_hash(data1: &[u8], data2: &[u8]) -> [u8; 32] {
|
|
||||||
let mut hash = Blake2s256::new();
|
|
||||||
hash.update(data1);
|
|
||||||
hash.update(data2);
|
|
||||||
hash.finalize().into()
|
|
||||||
}
|
|
||||||
|
|
||||||
#[inline]
|
|
||||||
/// RFC 2401 HMAC+Blake2s, not to be confused with *keyed* Blake2s
|
|
||||||
pub(crate) fn b2s_hmac(key: &[u8], data1: &[u8]) -> [u8; 32] {
|
|
||||||
use blake2::digest::Update;
|
|
||||||
type HmacBlake2s = hmac::SimpleHmac<Blake2s256>;
|
|
||||||
let mut hmac = HmacBlake2s::new_from_slice(key).unwrap();
|
|
||||||
hmac.update(data1);
|
|
||||||
hmac.finalize_fixed().into()
|
|
||||||
}
|
|
||||||
|
|
||||||
#[inline]
|
|
||||||
/// Like b2s_hmac, but chain data1 and data2 together
|
|
||||||
pub(crate) fn b2s_hmac2(key: &[u8], data1: &[u8], data2: &[u8]) -> [u8; 32] {
|
|
||||||
use blake2::digest::Update;
|
|
||||||
type HmacBlake2s = hmac::SimpleHmac<Blake2s256>;
|
|
||||||
let mut hmac = HmacBlake2s::new_from_slice(key).unwrap();
|
|
||||||
hmac.update(data1);
|
|
||||||
hmac.update(data2);
|
|
||||||
hmac.finalize_fixed().into()
|
|
||||||
}
|
|
||||||
|
|
||||||
#[inline]
|
|
||||||
pub(crate) fn b2s_keyed_mac_16(key: &[u8], data1: &[u8]) -> [u8; 16] {
|
|
||||||
let mut hmac = Blake2sMac::new_from_slice(key).unwrap();
|
|
||||||
blake2::digest::Update::update(&mut hmac, data1);
|
|
||||||
hmac.finalize_fixed().into()
|
|
||||||
}
|
|
||||||
|
|
||||||
#[inline]
|
|
||||||
pub(crate) fn b2s_keyed_mac_16_2(key: &[u8], data1: &[u8], data2: &[u8]) -> [u8; 16] {
|
|
||||||
let mut hmac = Blake2sMac::new_from_slice(key).unwrap();
|
|
||||||
blake2::digest::Update::update(&mut hmac, data1);
|
|
||||||
blake2::digest::Update::update(&mut hmac, data2);
|
|
||||||
hmac.finalize_fixed().into()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn b2s_mac_24(key: &[u8], data1: &[u8]) -> [u8; 24] {
|
|
||||||
let mut hmac = Blake2sMac::new_from_slice(key).unwrap();
|
|
||||||
blake2::digest::Update::update(&mut hmac, data1);
|
|
||||||
hmac.finalize_fixed().into()
|
|
||||||
}
|
|
||||||
|
|
||||||
#[inline]
|
|
||||||
/// This wrapper involves an extra copy and MAY BE SLOWER
|
|
||||||
fn aead_chacha20_seal(ciphertext: &mut [u8], key: &[u8], counter: u64, data: &[u8], aad: &[u8]) {
|
|
||||||
let mut nonce: [u8; 12] = [0; 12];
|
|
||||||
nonce[4..12].copy_from_slice(&counter.to_le_bytes());
|
|
||||||
|
|
||||||
aead_chacha20_seal_inner(ciphertext, key, nonce, data, aad)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[inline]
|
|
||||||
fn aead_chacha20_seal_inner(
|
|
||||||
ciphertext: &mut [u8],
|
|
||||||
key: &[u8],
|
|
||||||
nonce: [u8; 12],
|
|
||||||
data: &[u8],
|
|
||||||
aad: &[u8],
|
|
||||||
) {
|
|
||||||
let key = LessSafeKey::new(UnboundKey::new(&CHACHA20_POLY1305, key).unwrap());
|
|
||||||
|
|
||||||
ciphertext[..data.len()].copy_from_slice(data);
|
|
||||||
|
|
||||||
let tag = key
|
|
||||||
.seal_in_place_separate_tag(
|
|
||||||
Nonce::assume_unique_for_key(nonce),
|
|
||||||
Aad::from(aad),
|
|
||||||
&mut ciphertext[..data.len()],
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
ciphertext[data.len()..].copy_from_slice(tag.as_ref());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[inline]
|
|
||||||
/// This wrapper involves an extra copy and MAY BE SLOWER
|
|
||||||
fn aead_chacha20_open(
|
|
||||||
buffer: &mut [u8],
|
|
||||||
key: &[u8],
|
|
||||||
counter: u64,
|
|
||||||
data: &[u8],
|
|
||||||
aad: &[u8],
|
|
||||||
) -> Result<(), WireGuardError> {
|
|
||||||
let mut nonce: [u8; 12] = [0; 12];
|
|
||||||
nonce[4..].copy_from_slice(&counter.to_le_bytes());
|
|
||||||
|
|
||||||
aead_chacha20_open_inner(buffer, key, nonce, data, aad)
|
|
||||||
.map_err(|_| WireGuardError::InvalidAeadTag)?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[inline]
|
|
||||||
fn aead_chacha20_open_inner(
|
|
||||||
buffer: &mut [u8],
|
|
||||||
key: &[u8],
|
|
||||||
nonce: [u8; 12],
|
|
||||||
data: &[u8],
|
|
||||||
aad: &[u8],
|
|
||||||
) -> Result<(), ring::error::Unspecified> {
|
|
||||||
let key = LessSafeKey::new(UnboundKey::new(&CHACHA20_POLY1305, key).unwrap());
|
|
||||||
|
|
||||||
let mut inner_buffer = data.to_owned();
|
|
||||||
|
|
||||||
let plaintext = key.open_in_place(
|
|
||||||
Nonce::assume_unique_for_key(nonce),
|
|
||||||
Aad::from(aad),
|
|
||||||
&mut inner_buffer,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
buffer.copy_from_slice(plaintext);
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
/// This struct represents a 12 byte [Tai64N](https://cr.yp.to/libtai/tai64.html) timestamp
|
|
||||||
struct Tai64N {
|
|
||||||
secs: u64,
|
|
||||||
nano: u32,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
/// This struct computes a [Tai64N](https://cr.yp.to/libtai/tai64.html) timestamp from current system time
|
|
||||||
struct TimeStamper {
|
|
||||||
duration_at_start: Duration,
|
|
||||||
instant_at_start: Instant,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TimeStamper {
|
|
||||||
/// Create a new TimeStamper
|
|
||||||
pub fn new() -> TimeStamper {
|
|
||||||
TimeStamper {
|
|
||||||
duration_at_start: SystemTime::now()
|
|
||||||
.duration_since(SystemTime::UNIX_EPOCH)
|
|
||||||
.unwrap(),
|
|
||||||
instant_at_start: Instant::now(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Take time reading and generate a 12 byte timestamp
|
|
||||||
pub fn stamp(&self) -> [u8; 12] {
|
|
||||||
const TAI64_BASE: u64 = (1u64 << 62) + 37;
|
|
||||||
let mut ext_stamp = [0u8; 12];
|
|
||||||
let stamp = Instant::now().duration_since(self.instant_at_start) + self.duration_at_start;
|
|
||||||
ext_stamp[0..8].copy_from_slice(&(stamp.as_secs() + TAI64_BASE).to_be_bytes());
|
|
||||||
ext_stamp[8..12].copy_from_slice(&stamp.subsec_nanos().to_be_bytes());
|
|
||||||
ext_stamp
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Tai64N {
|
|
||||||
/// A zeroed out timestamp
|
|
||||||
fn zero() -> Tai64N {
|
|
||||||
Tai64N { secs: 0, nano: 0 }
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Parse a timestamp from a 12 byte u8 slice
|
|
||||||
fn parse(buf: &[u8; 12]) -> Result<Tai64N, WireGuardError> {
|
|
||||||
if buf.len() < 12 {
|
|
||||||
return Err(WireGuardError::InvalidTai64nTimestamp);
|
|
||||||
}
|
|
||||||
|
|
||||||
let (sec_bytes, nano_bytes) = buf.split_at(std::mem::size_of::<u64>());
|
|
||||||
let secs = u64::from_be_bytes(sec_bytes.try_into().unwrap());
|
|
||||||
let nano = u32::from_be_bytes(nano_bytes.try_into().unwrap());
|
|
||||||
|
|
||||||
// WireGuard does not actually expect tai64n timestamp, just monotonically increasing one
|
|
||||||
//if secs < (1u64 << 62) || secs >= (1u64 << 63) {
|
|
||||||
// return Err(WireGuardError::InvalidTai64nTimestamp);
|
|
||||||
//};
|
|
||||||
//if nano >= 1_000_000_000 {
|
|
||||||
// return Err(WireGuardError::InvalidTai64nTimestamp);
|
|
||||||
//}
|
|
||||||
|
|
||||||
Ok(Tai64N { secs, nano })
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Check if this timestamp represents a time that is chronologically after the time represented
|
|
||||||
/// by the other timestamp
|
|
||||||
pub fn after(&self, other: &Tai64N) -> bool {
|
|
||||||
(self.secs > other.secs) || ((self.secs == other.secs) && (self.nano > other.nano))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Parameters used by the noise protocol
|
|
||||||
struct NoiseParams {
|
|
||||||
/// Our static public key
|
|
||||||
static_public: x25519::PublicKey,
|
|
||||||
/// Our static private key
|
|
||||||
static_private: x25519::StaticSecret,
|
|
||||||
/// Static public key of the other party
|
|
||||||
peer_static_public: x25519::PublicKey,
|
|
||||||
/// A shared key = DH(static_private, peer_static_public)
|
|
||||||
static_shared: x25519::SharedSecret,
|
|
||||||
/// A pre-computation of HASH("mac1----", peer_static_public) for this peer
|
|
||||||
sending_mac1_key: [u8; KEY_LEN],
|
|
||||||
/// An optional preshared key
|
|
||||||
preshared_key: Option<[u8; KEY_LEN]>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::fmt::Debug for NoiseParams {
|
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
||||||
f.debug_struct("NoiseParams")
|
|
||||||
.field("static_public", &self.static_public)
|
|
||||||
.field("static_private", &"<redacted>")
|
|
||||||
.field("peer_static_public", &self.peer_static_public)
|
|
||||||
.field("static_shared", &"<redacted>")
|
|
||||||
.field("sending_mac1_key", &self.sending_mac1_key)
|
|
||||||
.field("preshared_key", &self.preshared_key)
|
|
||||||
.finish()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct HandshakeInitSentState {
|
|
||||||
local_index: u32,
|
|
||||||
hash: [u8; KEY_LEN],
|
|
||||||
chaining_key: [u8; KEY_LEN],
|
|
||||||
ephemeral_private: x25519::ReusableSecret,
|
|
||||||
time_sent: Instant,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::fmt::Debug for HandshakeInitSentState {
|
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
||||||
f.debug_struct("HandshakeInitSentState")
|
|
||||||
.field("local_index", &self.local_index)
|
|
||||||
.field("hash", &self.hash)
|
|
||||||
.field("chaining_key", &self.chaining_key)
|
|
||||||
.field("ephemeral_private", &"<redacted>")
|
|
||||||
.field("time_sent", &self.time_sent)
|
|
||||||
.finish()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
enum HandshakeState {
|
|
||||||
/// No handshake in process
|
|
||||||
None,
|
|
||||||
/// We initiated the handshake
|
|
||||||
InitSent(HandshakeInitSentState),
|
|
||||||
/// Handshake initiated by peer
|
|
||||||
InitReceived {
|
|
||||||
hash: [u8; KEY_LEN],
|
|
||||||
chaining_key: [u8; KEY_LEN],
|
|
||||||
peer_ephemeral_public: x25519::PublicKey,
|
|
||||||
peer_index: u32,
|
|
||||||
},
|
|
||||||
/// Handshake was established too long ago (implies no handshake is in progress)
|
|
||||||
Expired,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct Handshake {
|
|
||||||
params: NoiseParams,
|
|
||||||
/// Index of the next session
|
|
||||||
next_index: u32,
|
|
||||||
/// Allow to have two outgoing handshakes in flight, because sometimes we may receive a delayed response to a handshake with bad networks
|
|
||||||
previous: HandshakeState,
|
|
||||||
/// Current handshake state
|
|
||||||
state: HandshakeState,
|
|
||||||
cookies: Cookies,
|
|
||||||
/// The timestamp of the last handshake we received
|
|
||||||
last_handshake_timestamp: Tai64N,
|
|
||||||
// TODO: make TimeStamper a singleton
|
|
||||||
stamper: TimeStamper,
|
|
||||||
pub(super) last_rtt: Option<u32>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Default)]
|
|
||||||
struct Cookies {
|
|
||||||
last_mac1: Option<[u8; 16]>,
|
|
||||||
index: u32,
|
|
||||||
write_cookie: Option<[u8; 16]>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct HalfHandshake {
|
|
||||||
pub peer_index: u32,
|
|
||||||
pub peer_static_public: [u8; 32],
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn parse_handshake_anon(
|
|
||||||
static_private: &x25519::StaticSecret,
|
|
||||||
static_public: &x25519::PublicKey,
|
|
||||||
packet: &HandshakeInit,
|
|
||||||
) -> Result<HalfHandshake, WireGuardError> {
|
|
||||||
let peer_index = packet.sender_idx;
|
|
||||||
// initiator.chaining_key = HASH(CONSTRUCTION)
|
|
||||||
let mut chaining_key = INITIAL_CHAIN_KEY;
|
|
||||||
// initiator.hash = HASH(HASH(initiator.chaining_key || IDENTIFIER) || responder.static_public)
|
|
||||||
let mut hash = INITIAL_CHAIN_HASH;
|
|
||||||
hash = b2s_hash(&hash, static_public.as_bytes());
|
|
||||||
// msg.unencrypted_ephemeral = DH_PUBKEY(initiator.ephemeral_private)
|
|
||||||
let peer_ephemeral_public = x25519::PublicKey::from(*packet.unencrypted_ephemeral);
|
|
||||||
// initiator.hash = HASH(initiator.hash || msg.unencrypted_ephemeral)
|
|
||||||
hash = b2s_hash(&hash, peer_ephemeral_public.as_bytes());
|
|
||||||
// temp = HMAC(initiator.chaining_key, msg.unencrypted_ephemeral)
|
|
||||||
// initiator.chaining_key = HMAC(temp, 0x1)
|
|
||||||
chaining_key = b2s_hmac(
|
|
||||||
&b2s_hmac(&chaining_key, peer_ephemeral_public.as_bytes()),
|
|
||||||
&[0x01],
|
|
||||||
);
|
|
||||||
// temp = HMAC(initiator.chaining_key, DH(initiator.ephemeral_private, responder.static_public))
|
|
||||||
let ephemeral_shared = static_private.diffie_hellman(&peer_ephemeral_public);
|
|
||||||
let temp = b2s_hmac(&chaining_key, &ephemeral_shared.to_bytes());
|
|
||||||
// initiator.chaining_key = HMAC(temp, 0x1)
|
|
||||||
chaining_key = b2s_hmac(&temp, &[0x01]);
|
|
||||||
// key = HMAC(temp, initiator.chaining_key || 0x2)
|
|
||||||
let key = b2s_hmac2(&temp, &chaining_key, &[0x02]);
|
|
||||||
|
|
||||||
let mut peer_static_public = [0u8; KEY_LEN];
|
|
||||||
// msg.encrypted_static = AEAD(key, 0, initiator.static_public, initiator.hash)
|
|
||||||
aead_chacha20_open(
|
|
||||||
&mut peer_static_public,
|
|
||||||
&key,
|
|
||||||
0,
|
|
||||||
packet.encrypted_static,
|
|
||||||
&hash,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
Ok(HalfHandshake {
|
|
||||||
peer_index,
|
|
||||||
peer_static_public,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
impl NoiseParams {
|
|
||||||
/// New noise params struct from our secret key, peers public key, and optional preshared key
|
|
||||||
fn new(
|
|
||||||
static_private: x25519::StaticSecret,
|
|
||||||
static_public: x25519::PublicKey,
|
|
||||||
peer_static_public: x25519::PublicKey,
|
|
||||||
preshared_key: Option<[u8; 32]>,
|
|
||||||
) -> Result<NoiseParams, WireGuardError> {
|
|
||||||
let static_shared = static_private.diffie_hellman(&peer_static_public);
|
|
||||||
|
|
||||||
let initial_sending_mac_key = b2s_hash(LABEL_MAC1, peer_static_public.as_bytes());
|
|
||||||
|
|
||||||
Ok(NoiseParams {
|
|
||||||
static_public,
|
|
||||||
static_private,
|
|
||||||
peer_static_public,
|
|
||||||
static_shared,
|
|
||||||
sending_mac1_key: initial_sending_mac_key,
|
|
||||||
preshared_key,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Set a new private key
|
|
||||||
fn set_static_private(
|
|
||||||
&mut self,
|
|
||||||
static_private: x25519::StaticSecret,
|
|
||||||
static_public: x25519::PublicKey,
|
|
||||||
) -> Result<(), WireGuardError> {
|
|
||||||
// Check that the public key indeed matches the private key
|
|
||||||
let check_key = x25519::PublicKey::from(&static_private);
|
|
||||||
assert_eq!(check_key.as_bytes(), static_public.as_bytes());
|
|
||||||
|
|
||||||
self.static_private = static_private;
|
|
||||||
self.static_public = static_public;
|
|
||||||
|
|
||||||
self.static_shared = self.static_private.diffie_hellman(&self.peer_static_public);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Handshake {
|
|
||||||
pub(crate) fn new(
|
|
||||||
static_private: x25519::StaticSecret,
|
|
||||||
static_public: x25519::PublicKey,
|
|
||||||
peer_static_public: x25519::PublicKey,
|
|
||||||
global_idx: u32,
|
|
||||||
preshared_key: Option<[u8; 32]>,
|
|
||||||
) -> Result<Handshake, WireGuardError> {
|
|
||||||
let params = NoiseParams::new(
|
|
||||||
static_private,
|
|
||||||
static_public,
|
|
||||||
peer_static_public,
|
|
||||||
preshared_key,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
Ok(Handshake {
|
|
||||||
params,
|
|
||||||
next_index: global_idx,
|
|
||||||
previous: HandshakeState::None,
|
|
||||||
state: HandshakeState::None,
|
|
||||||
last_handshake_timestamp: Tai64N::zero(),
|
|
||||||
stamper: TimeStamper::new(),
|
|
||||||
cookies: Default::default(),
|
|
||||||
last_rtt: None,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn is_in_progress(&self) -> bool {
|
|
||||||
!matches!(self.state, HandshakeState::None | HandshakeState::Expired)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn timer(&self) -> Option<Instant> {
|
|
||||||
match self.state {
|
|
||||||
HandshakeState::InitSent(HandshakeInitSentState { time_sent, .. }) => Some(time_sent),
|
|
||||||
_ => None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn set_expired(&mut self) {
|
|
||||||
self.previous = HandshakeState::Expired;
|
|
||||||
self.state = HandshakeState::Expired;
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn is_expired(&self) -> bool {
|
|
||||||
matches!(self.state, HandshakeState::Expired)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn has_cookie(&self) -> bool {
|
|
||||||
self.cookies.write_cookie.is_some()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn clear_cookie(&mut self) {
|
|
||||||
self.cookies.write_cookie = None;
|
|
||||||
}
|
|
||||||
|
|
||||||
// The index used is 24 bits for peer index, allowing for 16M active peers per server and 8 bits for cyclic session index
|
|
||||||
fn inc_index(&mut self) -> u32 {
|
|
||||||
let index = self.next_index;
|
|
||||||
let idx8 = index as u8;
|
|
||||||
self.next_index = (index & !0xff) | u32::from(idx8.wrapping_add(1));
|
|
||||||
self.next_index
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn set_static_private(
|
|
||||||
&mut self,
|
|
||||||
private_key: x25519::StaticSecret,
|
|
||||||
public_key: x25519::PublicKey,
|
|
||||||
) -> Result<(), WireGuardError> {
|
|
||||||
self.params.set_static_private(private_key, public_key)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(super) fn receive_handshake_initialization<'a>(
|
|
||||||
&mut self,
|
|
||||||
packet: HandshakeInit,
|
|
||||||
dst: &'a mut [u8],
|
|
||||||
) -> Result<(&'a mut [u8], Session), WireGuardError> {
|
|
||||||
// initiator.chaining_key = HASH(CONSTRUCTION)
|
|
||||||
let mut chaining_key = INITIAL_CHAIN_KEY;
|
|
||||||
// initiator.hash = HASH(HASH(initiator.chaining_key || IDENTIFIER) || responder.static_public)
|
|
||||||
let mut hash = INITIAL_CHAIN_HASH;
|
|
||||||
hash = b2s_hash(&hash, self.params.static_public.as_bytes());
|
|
||||||
// msg.sender_index = little_endian(initiator.sender_index)
|
|
||||||
let peer_index = packet.sender_idx;
|
|
||||||
// msg.unencrypted_ephemeral = DH_PUBKEY(initiator.ephemeral_private)
|
|
||||||
let peer_ephemeral_public = x25519::PublicKey::from(*packet.unencrypted_ephemeral);
|
|
||||||
// initiator.hash = HASH(initiator.hash || msg.unencrypted_ephemeral)
|
|
||||||
hash = b2s_hash(&hash, peer_ephemeral_public.as_bytes());
|
|
||||||
// temp = HMAC(initiator.chaining_key, msg.unencrypted_ephemeral)
|
|
||||||
// initiator.chaining_key = HMAC(temp, 0x1)
|
|
||||||
chaining_key = b2s_hmac(
|
|
||||||
&b2s_hmac(&chaining_key, peer_ephemeral_public.as_bytes()),
|
|
||||||
&[0x01],
|
|
||||||
);
|
|
||||||
// temp = HMAC(initiator.chaining_key, DH(initiator.ephemeral_private, responder.static_public))
|
|
||||||
let ephemeral_shared = self
|
|
||||||
.params
|
|
||||||
.static_private
|
|
||||||
.diffie_hellman(&peer_ephemeral_public);
|
|
||||||
let temp = b2s_hmac(&chaining_key, &ephemeral_shared.to_bytes());
|
|
||||||
// initiator.chaining_key = HMAC(temp, 0x1)
|
|
||||||
chaining_key = b2s_hmac(&temp, &[0x01]);
|
|
||||||
// key = HMAC(temp, initiator.chaining_key || 0x2)
|
|
||||||
let key = b2s_hmac2(&temp, &chaining_key, &[0x02]);
|
|
||||||
|
|
||||||
let mut peer_static_public_decrypted = [0u8; KEY_LEN];
|
|
||||||
// msg.encrypted_static = AEAD(key, 0, initiator.static_public, initiator.hash)
|
|
||||||
aead_chacha20_open(
|
|
||||||
&mut peer_static_public_decrypted,
|
|
||||||
&key,
|
|
||||||
0,
|
|
||||||
packet.encrypted_static,
|
|
||||||
&hash,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
ring::constant_time::verify_slices_are_equal(
|
|
||||||
self.params.peer_static_public.as_bytes(),
|
|
||||||
&peer_static_public_decrypted,
|
|
||||||
)
|
|
||||||
.map_err(|_| WireGuardError::WrongKey)?;
|
|
||||||
|
|
||||||
// initiator.hash = HASH(initiator.hash || msg.encrypted_static)
|
|
||||||
hash = b2s_hash(&hash, packet.encrypted_static);
|
|
||||||
// temp = HMAC(initiator.chaining_key, DH(initiator.static_private, responder.static_public))
|
|
||||||
let temp = b2s_hmac(&chaining_key, self.params.static_shared.as_bytes());
|
|
||||||
// initiator.chaining_key = HMAC(temp, 0x1)
|
|
||||||
chaining_key = b2s_hmac(&temp, &[0x01]);
|
|
||||||
// key = HMAC(temp, initiator.chaining_key || 0x2)
|
|
||||||
let key = b2s_hmac2(&temp, &chaining_key, &[0x02]);
|
|
||||||
// msg.encrypted_timestamp = AEAD(key, 0, TAI64N(), initiator.hash)
|
|
||||||
let mut timestamp = [0u8; TIMESTAMP_LEN];
|
|
||||||
aead_chacha20_open(&mut timestamp, &key, 0, packet.encrypted_timestamp, &hash)?;
|
|
||||||
|
|
||||||
let timestamp = Tai64N::parse(×tamp)?;
|
|
||||||
if !timestamp.after(&self.last_handshake_timestamp) {
|
|
||||||
// Possibly a replay
|
|
||||||
return Err(WireGuardError::WrongTai64nTimestamp);
|
|
||||||
}
|
|
||||||
self.last_handshake_timestamp = timestamp;
|
|
||||||
|
|
||||||
// initiator.hash = HASH(initiator.hash || msg.encrypted_timestamp)
|
|
||||||
hash = b2s_hash(&hash, packet.encrypted_timestamp);
|
|
||||||
|
|
||||||
self.previous = std::mem::replace(
|
|
||||||
&mut self.state,
|
|
||||||
HandshakeState::InitReceived {
|
|
||||||
chaining_key,
|
|
||||||
hash,
|
|
||||||
peer_ephemeral_public,
|
|
||||||
peer_index,
|
|
||||||
},
|
|
||||||
);
|
|
||||||
|
|
||||||
self.format_handshake_response(dst)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(super) fn receive_handshake_response(
|
|
||||||
&mut self,
|
|
||||||
packet: HandshakeResponse,
|
|
||||||
) -> Result<Session, WireGuardError> {
|
|
||||||
// Check if there is a handshake awaiting a response and return the correct one
|
|
||||||
let (state, is_previous) = match (&self.state, &self.previous) {
|
|
||||||
(HandshakeState::InitSent(s), _) if s.local_index == packet.receiver_idx => (s, false),
|
|
||||||
(_, HandshakeState::InitSent(s)) if s.local_index == packet.receiver_idx => (s, true),
|
|
||||||
_ => return Err(WireGuardError::UnexpectedPacket),
|
|
||||||
};
|
|
||||||
|
|
||||||
let peer_index = packet.sender_idx;
|
|
||||||
let local_index = state.local_index;
|
|
||||||
|
|
||||||
let unencrypted_ephemeral = x25519::PublicKey::from(*packet.unencrypted_ephemeral);
|
|
||||||
// msg.unencrypted_ephemeral = DH_PUBKEY(responder.ephemeral_private)
|
|
||||||
// responder.hash = HASH(responder.hash || msg.unencrypted_ephemeral)
|
|
||||||
let mut hash = b2s_hash(&state.hash, unencrypted_ephemeral.as_bytes());
|
|
||||||
// temp = HMAC(responder.chaining_key, msg.unencrypted_ephemeral)
|
|
||||||
let temp = b2s_hmac(&state.chaining_key, unencrypted_ephemeral.as_bytes());
|
|
||||||
// responder.chaining_key = HMAC(temp, 0x1)
|
|
||||||
let mut chaining_key = b2s_hmac(&temp, &[0x01]);
|
|
||||||
// temp = HMAC(responder.chaining_key, DH(responder.ephemeral_private, initiator.ephemeral_public))
|
|
||||||
let ephemeral_shared = state
|
|
||||||
.ephemeral_private
|
|
||||||
.diffie_hellman(&unencrypted_ephemeral);
|
|
||||||
let temp = b2s_hmac(&chaining_key, &ephemeral_shared.to_bytes());
|
|
||||||
// responder.chaining_key = HMAC(temp, 0x1)
|
|
||||||
chaining_key = b2s_hmac(&temp, &[0x01]);
|
|
||||||
// temp = HMAC(responder.chaining_key, DH(responder.ephemeral_private, initiator.static_public))
|
|
||||||
let temp = b2s_hmac(
|
|
||||||
&chaining_key,
|
|
||||||
&self
|
|
||||||
.params
|
|
||||||
.static_private
|
|
||||||
.diffie_hellman(&unencrypted_ephemeral)
|
|
||||||
.to_bytes(),
|
|
||||||
);
|
|
||||||
// responder.chaining_key = HMAC(temp, 0x1)
|
|
||||||
chaining_key = b2s_hmac(&temp, &[0x01]);
|
|
||||||
// temp = HMAC(responder.chaining_key, preshared_key)
|
|
||||||
let temp = b2s_hmac(
|
|
||||||
&chaining_key,
|
|
||||||
&self.params.preshared_key.unwrap_or([0u8; 32])[..],
|
|
||||||
);
|
|
||||||
// responder.chaining_key = HMAC(temp, 0x1)
|
|
||||||
chaining_key = b2s_hmac(&temp, &[0x01]);
|
|
||||||
// temp2 = HMAC(temp, responder.chaining_key || 0x2)
|
|
||||||
let temp2 = b2s_hmac2(&temp, &chaining_key, &[0x02]);
|
|
||||||
// key = HMAC(temp, temp2 || 0x3)
|
|
||||||
let key = b2s_hmac2(&temp, &temp2, &[0x03]);
|
|
||||||
// responder.hash = HASH(responder.hash || temp2)
|
|
||||||
hash = b2s_hash(&hash, &temp2);
|
|
||||||
// msg.encrypted_nothing = AEAD(key, 0, [empty], responder.hash)
|
|
||||||
aead_chacha20_open(&mut [], &key, 0, packet.encrypted_nothing, &hash)?;
|
|
||||||
|
|
||||||
// responder.hash = HASH(responder.hash || msg.encrypted_nothing)
|
|
||||||
// hash = b2s_hash(hash, buf[ENC_NOTHING_OFF..ENC_NOTHING_OFF + ENC_NOTHING_SZ]);
|
|
||||||
|
|
||||||
// Derive keys
|
|
||||||
// temp1 = HMAC(initiator.chaining_key, [empty])
|
|
||||||
// temp2 = HMAC(temp1, 0x1)
|
|
||||||
// temp3 = HMAC(temp1, temp2 || 0x2)
|
|
||||||
// initiator.sending_key = temp2
|
|
||||||
// initiator.receiving_key = temp3
|
|
||||||
// initiator.sending_key_counter = 0
|
|
||||||
// initiator.receiving_key_counter = 0
|
|
||||||
let temp1 = b2s_hmac(&chaining_key, &[]);
|
|
||||||
let temp2 = b2s_hmac(&temp1, &[0x01]);
|
|
||||||
let temp3 = b2s_hmac2(&temp1, &temp2, &[0x02]);
|
|
||||||
|
|
||||||
let rtt_time = Instant::now().duration_since(state.time_sent);
|
|
||||||
self.last_rtt = Some(rtt_time.as_millis() as u32);
|
|
||||||
|
|
||||||
if is_previous {
|
|
||||||
self.previous = HandshakeState::None;
|
|
||||||
} else {
|
|
||||||
self.state = HandshakeState::None;
|
|
||||||
}
|
|
||||||
Ok(Session::new(local_index, peer_index, temp3, temp2))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(super) fn receive_cookie_reply(
|
|
||||||
&mut self,
|
|
||||||
packet: PacketCookieReply,
|
|
||||||
) -> Result<(), WireGuardError> {
|
|
||||||
let mac1 = match self.cookies.last_mac1 {
|
|
||||||
Some(mac) => mac,
|
|
||||||
None => {
|
|
||||||
return Err(WireGuardError::UnexpectedPacket);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let local_index = self.cookies.index;
|
|
||||||
if packet.receiver_idx != local_index {
|
|
||||||
return Err(WireGuardError::WrongIndex);
|
|
||||||
}
|
|
||||||
// msg.encrypted_cookie = XAEAD(HASH(LABEL_COOKIE || responder.static_public), msg.nonce, cookie, last_received_msg.mac1)
|
|
||||||
let key = b2s_hash(LABEL_COOKIE, self.params.peer_static_public.as_bytes()); // TODO: pre-compute
|
|
||||||
|
|
||||||
let payload = Payload {
|
|
||||||
aad: &mac1[0..16],
|
|
||||||
msg: packet.encrypted_cookie,
|
|
||||||
};
|
|
||||||
let plaintext = XChaCha20Poly1305::new_from_slice(&key)
|
|
||||||
.unwrap()
|
|
||||||
.decrypt(packet.nonce.into(), payload)
|
|
||||||
.map_err(|_| WireGuardError::InvalidAeadTag)?;
|
|
||||||
|
|
||||||
let cookie = plaintext
|
|
||||||
.try_into()
|
|
||||||
.map_err(|_| WireGuardError::InvalidPacket)?;
|
|
||||||
self.cookies.write_cookie = Some(cookie);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Compute and append mac1 and mac2 to a handshake message
|
|
||||||
fn append_mac1_and_mac2<'a>(
|
|
||||||
&mut self,
|
|
||||||
local_index: u32,
|
|
||||||
dst: &'a mut [u8],
|
|
||||||
) -> Result<&'a mut [u8], WireGuardError> {
|
|
||||||
let mac1_off = dst.len() - 32;
|
|
||||||
let mac2_off = dst.len() - 16;
|
|
||||||
|
|
||||||
// msg.mac1 = MAC(HASH(LABEL_MAC1 || responder.static_public), msg[0:offsetof(msg.mac1)])
|
|
||||||
let msg_mac1 = b2s_keyed_mac_16(&self.params.sending_mac1_key, &dst[..mac1_off]);
|
|
||||||
|
|
||||||
dst[mac1_off..mac2_off].copy_from_slice(&msg_mac1[..]);
|
|
||||||
|
|
||||||
//msg.mac2 = MAC(initiator.last_received_cookie, msg[0:offsetof(msg.mac2)])
|
|
||||||
let msg_mac2: [u8; 16] = if let Some(cookie) = self.cookies.write_cookie {
|
|
||||||
b2s_keyed_mac_16(&cookie, &dst[..mac2_off])
|
|
||||||
} else {
|
|
||||||
[0u8; 16]
|
|
||||||
};
|
|
||||||
|
|
||||||
dst[mac2_off..].copy_from_slice(&msg_mac2[..]);
|
|
||||||
|
|
||||||
self.cookies.index = local_index;
|
|
||||||
self.cookies.last_mac1 = Some(msg_mac1);
|
|
||||||
Ok(dst)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(super) fn format_handshake_initiation<'a>(
|
|
||||||
&mut self,
|
|
||||||
dst: &'a mut [u8],
|
|
||||||
) -> Result<&'a mut [u8], WireGuardError> {
|
|
||||||
if dst.len() < super::HANDSHAKE_INIT_SZ {
|
|
||||||
return Err(WireGuardError::DestinationBufferTooSmall);
|
|
||||||
}
|
|
||||||
|
|
||||||
let (message_type, rest) = dst.split_at_mut(4);
|
|
||||||
let (sender_index, rest) = rest.split_at_mut(4);
|
|
||||||
let (unencrypted_ephemeral, rest) = rest.split_at_mut(32);
|
|
||||||
let (encrypted_static, rest) = rest.split_at_mut(32 + 16);
|
|
||||||
let (encrypted_timestamp, _) = rest.split_at_mut(12 + 16);
|
|
||||||
|
|
||||||
let local_index = self.inc_index();
|
|
||||||
|
|
||||||
// initiator.chaining_key = HASH(CONSTRUCTION)
|
|
||||||
let mut chaining_key = INITIAL_CHAIN_KEY;
|
|
||||||
// initiator.hash = HASH(HASH(initiator.chaining_key || IDENTIFIER) || responder.static_public)
|
|
||||||
let mut hash = INITIAL_CHAIN_HASH;
|
|
||||||
hash = b2s_hash(&hash, self.params.peer_static_public.as_bytes());
|
|
||||||
// initiator.ephemeral_private = DH_GENERATE()
|
|
||||||
let ephemeral_private = x25519::ReusableSecret::random_from_rng(OsRng);
|
|
||||||
// msg.message_type = 1
|
|
||||||
// msg.reserved_zero = { 0, 0, 0 }
|
|
||||||
message_type.copy_from_slice(&super::HANDSHAKE_INIT.to_le_bytes());
|
|
||||||
// msg.sender_index = little_endian(initiator.sender_index)
|
|
||||||
sender_index.copy_from_slice(&local_index.to_le_bytes());
|
|
||||||
// msg.unencrypted_ephemeral = DH_PUBKEY(initiator.ephemeral_private)
|
|
||||||
unencrypted_ephemeral
|
|
||||||
.copy_from_slice(x25519::PublicKey::from(&ephemeral_private).as_bytes());
|
|
||||||
// initiator.hash = HASH(initiator.hash || msg.unencrypted_ephemeral)
|
|
||||||
hash = b2s_hash(&hash, unencrypted_ephemeral);
|
|
||||||
// temp = HMAC(initiator.chaining_key, msg.unencrypted_ephemeral)
|
|
||||||
// initiator.chaining_key = HMAC(temp, 0x1)
|
|
||||||
chaining_key = b2s_hmac(&b2s_hmac(&chaining_key, unencrypted_ephemeral), &[0x01]);
|
|
||||||
// temp = HMAC(initiator.chaining_key, DH(initiator.ephemeral_private, responder.static_public))
|
|
||||||
let ephemeral_shared = ephemeral_private.diffie_hellman(&self.params.peer_static_public);
|
|
||||||
let temp = b2s_hmac(&chaining_key, &ephemeral_shared.to_bytes());
|
|
||||||
// initiator.chaining_key = HMAC(temp, 0x1)
|
|
||||||
chaining_key = b2s_hmac(&temp, &[0x01]);
|
|
||||||
// key = HMAC(temp, initiator.chaining_key || 0x2)
|
|
||||||
let key = b2s_hmac2(&temp, &chaining_key, &[0x02]);
|
|
||||||
// msg.encrypted_static = AEAD(key, 0, initiator.static_public, initiator.hash)
|
|
||||||
aead_chacha20_seal(
|
|
||||||
encrypted_static,
|
|
||||||
&key,
|
|
||||||
0,
|
|
||||||
self.params.static_public.as_bytes(),
|
|
||||||
&hash,
|
|
||||||
);
|
|
||||||
// initiator.hash = HASH(initiator.hash || msg.encrypted_static)
|
|
||||||
hash = b2s_hash(&hash, encrypted_static);
|
|
||||||
// temp = HMAC(initiator.chaining_key, DH(initiator.static_private, responder.static_public))
|
|
||||||
let temp = b2s_hmac(&chaining_key, self.params.static_shared.as_bytes());
|
|
||||||
// initiator.chaining_key = HMAC(temp, 0x1)
|
|
||||||
chaining_key = b2s_hmac(&temp, &[0x01]);
|
|
||||||
// key = HMAC(temp, initiator.chaining_key || 0x2)
|
|
||||||
let key = b2s_hmac2(&temp, &chaining_key, &[0x02]);
|
|
||||||
// msg.encrypted_timestamp = AEAD(key, 0, TAI64N(), initiator.hash)
|
|
||||||
let timestamp = self.stamper.stamp();
|
|
||||||
aead_chacha20_seal(encrypted_timestamp, &key, 0, ×tamp, &hash);
|
|
||||||
// initiator.hash = HASH(initiator.hash || msg.encrypted_timestamp)
|
|
||||||
hash = b2s_hash(&hash, encrypted_timestamp);
|
|
||||||
|
|
||||||
let time_now = Instant::now();
|
|
||||||
self.previous = std::mem::replace(
|
|
||||||
&mut self.state,
|
|
||||||
HandshakeState::InitSent(HandshakeInitSentState {
|
|
||||||
local_index,
|
|
||||||
chaining_key,
|
|
||||||
hash,
|
|
||||||
ephemeral_private,
|
|
||||||
time_sent: time_now,
|
|
||||||
}),
|
|
||||||
);
|
|
||||||
|
|
||||||
self.append_mac1_and_mac2(local_index, &mut dst[..super::HANDSHAKE_INIT_SZ])
|
|
||||||
}
|
|
||||||
|
|
||||||
fn format_handshake_response<'a>(
|
|
||||||
&mut self,
|
|
||||||
dst: &'a mut [u8],
|
|
||||||
) -> Result<(&'a mut [u8], Session), WireGuardError> {
|
|
||||||
if dst.len() < super::HANDSHAKE_RESP_SZ {
|
|
||||||
return Err(WireGuardError::DestinationBufferTooSmall);
|
|
||||||
}
|
|
||||||
|
|
||||||
let state = std::mem::replace(&mut self.state, HandshakeState::None);
|
|
||||||
let (mut chaining_key, mut hash, peer_ephemeral_public, peer_index) = match state {
|
|
||||||
HandshakeState::InitReceived {
|
|
||||||
chaining_key,
|
|
||||||
hash,
|
|
||||||
peer_ephemeral_public,
|
|
||||||
peer_index,
|
|
||||||
} => (chaining_key, hash, peer_ephemeral_public, peer_index),
|
|
||||||
_ => {
|
|
||||||
panic!("Unexpected attempt to call send_handshake_response");
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let (message_type, rest) = dst.split_at_mut(4);
|
|
||||||
let (sender_index, rest) = rest.split_at_mut(4);
|
|
||||||
let (receiver_index, rest) = rest.split_at_mut(4);
|
|
||||||
let (unencrypted_ephemeral, rest) = rest.split_at_mut(32);
|
|
||||||
let (encrypted_nothing, _) = rest.split_at_mut(16);
|
|
||||||
|
|
||||||
// responder.ephemeral_private = DH_GENERATE()
|
|
||||||
let ephemeral_private = x25519::ReusableSecret::random_from_rng(OsRng);
|
|
||||||
let local_index = self.inc_index();
|
|
||||||
// msg.message_type = 2
|
|
||||||
// msg.reserved_zero = { 0, 0, 0 }
|
|
||||||
message_type.copy_from_slice(&super::HANDSHAKE_RESP.to_le_bytes());
|
|
||||||
// msg.sender_index = little_endian(responder.sender_index)
|
|
||||||
sender_index.copy_from_slice(&local_index.to_le_bytes());
|
|
||||||
// msg.receiver_index = little_endian(initiator.sender_index)
|
|
||||||
receiver_index.copy_from_slice(&peer_index.to_le_bytes());
|
|
||||||
// msg.unencrypted_ephemeral = DH_PUBKEY(initiator.ephemeral_private)
|
|
||||||
unencrypted_ephemeral
|
|
||||||
.copy_from_slice(x25519::PublicKey::from(&ephemeral_private).as_bytes());
|
|
||||||
// responder.hash = HASH(responder.hash || msg.unencrypted_ephemeral)
|
|
||||||
hash = b2s_hash(&hash, unencrypted_ephemeral);
|
|
||||||
// temp = HMAC(responder.chaining_key, msg.unencrypted_ephemeral)
|
|
||||||
let temp = b2s_hmac(&chaining_key, unencrypted_ephemeral);
|
|
||||||
// responder.chaining_key = HMAC(temp, 0x1)
|
|
||||||
chaining_key = b2s_hmac(&temp, &[0x01]);
|
|
||||||
// temp = HMAC(responder.chaining_key, DH(responder.ephemeral_private, initiator.ephemeral_public))
|
|
||||||
let ephemeral_shared = ephemeral_private.diffie_hellman(&peer_ephemeral_public);
|
|
||||||
let temp = b2s_hmac(&chaining_key, &ephemeral_shared.to_bytes());
|
|
||||||
// responder.chaining_key = HMAC(temp, 0x1)
|
|
||||||
chaining_key = b2s_hmac(&temp, &[0x01]);
|
|
||||||
// temp = HMAC(responder.chaining_key, DH(responder.ephemeral_private, initiator.static_public))
|
|
||||||
let temp = b2s_hmac(
|
|
||||||
&chaining_key,
|
|
||||||
&ephemeral_private
|
|
||||||
.diffie_hellman(&self.params.peer_static_public)
|
|
||||||
.to_bytes(),
|
|
||||||
);
|
|
||||||
// responder.chaining_key = HMAC(temp, 0x1)
|
|
||||||
chaining_key = b2s_hmac(&temp, &[0x01]);
|
|
||||||
// temp = HMAC(responder.chaining_key, preshared_key)
|
|
||||||
let temp = b2s_hmac(
|
|
||||||
&chaining_key,
|
|
||||||
&self.params.preshared_key.unwrap_or([0u8; 32])[..],
|
|
||||||
);
|
|
||||||
// responder.chaining_key = HMAC(temp, 0x1)
|
|
||||||
chaining_key = b2s_hmac(&temp, &[0x01]);
|
|
||||||
// temp2 = HMAC(temp, responder.chaining_key || 0x2)
|
|
||||||
let temp2 = b2s_hmac2(&temp, &chaining_key, &[0x02]);
|
|
||||||
// key = HMAC(temp, temp2 || 0x3)
|
|
||||||
let key = b2s_hmac2(&temp, &temp2, &[0x03]);
|
|
||||||
// responder.hash = HASH(responder.hash || temp2)
|
|
||||||
hash = b2s_hash(&hash, &temp2);
|
|
||||||
// msg.encrypted_nothing = AEAD(key, 0, [empty], responder.hash)
|
|
||||||
aead_chacha20_seal(encrypted_nothing, &key, 0, &[], &hash);
|
|
||||||
|
|
||||||
// Derive keys
|
|
||||||
// temp1 = HMAC(initiator.chaining_key, [empty])
|
|
||||||
// temp2 = HMAC(temp1, 0x1)
|
|
||||||
// temp3 = HMAC(temp1, temp2 || 0x2)
|
|
||||||
// initiator.sending_key = temp2
|
|
||||||
// initiator.receiving_key = temp3
|
|
||||||
// initiator.sending_key_counter = 0
|
|
||||||
// initiator.receiving_key_counter = 0
|
|
||||||
let temp1 = b2s_hmac(&chaining_key, &[]);
|
|
||||||
let temp2 = b2s_hmac(&temp1, &[0x01]);
|
|
||||||
let temp3 = b2s_hmac2(&temp1, &temp2, &[0x02]);
|
|
||||||
|
|
||||||
let dst = self.append_mac1_and_mac2(local_index, &mut dst[..super::HANDSHAKE_RESP_SZ])?;
|
|
||||||
|
|
||||||
Ok((dst, Session::new(local_index, peer_index, temp2, temp3)))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn chacha20_seal_rfc7530_test_vector() {
|
|
||||||
let plaintext = b"Ladies and Gentlemen of the class of '99: If I could offer you only one tip for the future, sunscreen would be it.";
|
|
||||||
let aad: [u8; 12] = [
|
|
||||||
0x50, 0x51, 0x52, 0x53, 0xc0, 0xc1, 0xc2, 0xc3, 0xc4, 0xc5, 0xc6, 0xc7,
|
|
||||||
];
|
|
||||||
let key: [u8; 32] = [
|
|
||||||
0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d,
|
|
||||||
0x8e, 0x8f, 0x90, 0x91, 0x92, 0x93, 0x94, 0x95, 0x96, 0x97, 0x98, 0x99, 0x9a, 0x9b,
|
|
||||||
0x9c, 0x9d, 0x9e, 0x9f,
|
|
||||||
];
|
|
||||||
let nonce: [u8; 12] = [
|
|
||||||
0x07, 0x00, 0x00, 0x00, 0x40, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47,
|
|
||||||
];
|
|
||||||
let mut buffer = vec![0; plaintext.len() + 16];
|
|
||||||
|
|
||||||
aead_chacha20_seal_inner(&mut buffer, &key, nonce, plaintext, &aad);
|
|
||||||
|
|
||||||
const EXPECTED_CIPHERTEXT: [u8; 114] = [
|
|
||||||
0xd3, 0x1a, 0x8d, 0x34, 0x64, 0x8e, 0x60, 0xdb, 0x7b, 0x86, 0xaf, 0xbc, 0x53, 0xef,
|
|
||||||
0x7e, 0xc2, 0xa4, 0xad, 0xed, 0x51, 0x29, 0x6e, 0x08, 0xfe, 0xa9, 0xe2, 0xb5, 0xa7,
|
|
||||||
0x36, 0xee, 0x62, 0xd6, 0x3d, 0xbe, 0xa4, 0x5e, 0x8c, 0xa9, 0x67, 0x12, 0x82, 0xfa,
|
|
||||||
0xfb, 0x69, 0xda, 0x92, 0x72, 0x8b, 0x1a, 0x71, 0xde, 0x0a, 0x9e, 0x06, 0x0b, 0x29,
|
|
||||||
0x05, 0xd6, 0xa5, 0xb6, 0x7e, 0xcd, 0x3b, 0x36, 0x92, 0xdd, 0xbd, 0x7f, 0x2d, 0x77,
|
|
||||||
0x8b, 0x8c, 0x98, 0x03, 0xae, 0xe3, 0x28, 0x09, 0x1b, 0x58, 0xfa, 0xb3, 0x24, 0xe4,
|
|
||||||
0xfa, 0xd6, 0x75, 0x94, 0x55, 0x85, 0x80, 0x8b, 0x48, 0x31, 0xd7, 0xbc, 0x3f, 0xf4,
|
|
||||||
0xde, 0xf0, 0x8e, 0x4b, 0x7a, 0x9d, 0xe5, 0x76, 0xd2, 0x65, 0x86, 0xce, 0xc6, 0x4b,
|
|
||||||
0x61, 0x16,
|
|
||||||
];
|
|
||||||
const EXPECTED_TAG: [u8; 16] = [
|
|
||||||
0x1a, 0xe1, 0x0b, 0x59, 0x4f, 0x09, 0xe2, 0x6a, 0x7e, 0x90, 0x2e, 0xcb, 0xd0, 0x60,
|
|
||||||
0x06, 0x91,
|
|
||||||
];
|
|
||||||
|
|
||||||
assert_eq!(buffer[..plaintext.len()], EXPECTED_CIPHERTEXT);
|
|
||||||
assert_eq!(buffer[plaintext.len()..], EXPECTED_TAG);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn symmetric_chacha20_seal_open() {
|
|
||||||
let aad: [u8; 32] = Default::default();
|
|
||||||
let key: [u8; 32] = Default::default();
|
|
||||||
let counter = 0;
|
|
||||||
|
|
||||||
let mut encrypted_nothing: [u8; 16] = Default::default();
|
|
||||||
|
|
||||||
aead_chacha20_seal(&mut encrypted_nothing, &key, counter, &[], &aad);
|
|
||||||
|
|
||||||
eprintln!("encrypted_nothing: {:?}", encrypted_nothing);
|
|
||||||
|
|
||||||
aead_chacha20_open(&mut [], &key, counter, &encrypted_nothing, &aad)
|
|
||||||
.expect("Should open what we just sealed");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,799 +0,0 @@
|
||||||
// Copyright (c) 2019 Cloudflare, Inc. All rights reserved.
|
|
||||||
// SPDX-License-Identifier: BSD-3-Clause
|
|
||||||
|
|
||||||
pub mod errors;
|
|
||||||
pub mod handshake;
|
|
||||||
pub mod rate_limiter;
|
|
||||||
|
|
||||||
mod session;
|
|
||||||
mod timers;
|
|
||||||
|
|
||||||
use crate::noise::errors::WireGuardError;
|
|
||||||
use crate::noise::handshake::Handshake;
|
|
||||||
use crate::noise::rate_limiter::RateLimiter;
|
|
||||||
use crate::noise::timers::{TimerName, Timers};
|
|
||||||
use crate::x25519;
|
|
||||||
|
|
||||||
use std::collections::VecDeque;
|
|
||||||
use std::convert::{TryFrom, TryInto};
|
|
||||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
|
|
||||||
use std::sync::Arc;
|
|
||||||
use std::time::Duration;
|
|
||||||
|
|
||||||
/// The default value to use for rate limiting, when no other rate limiter is defined
|
|
||||||
const PEER_HANDSHAKE_RATE_LIMIT: u64 = 10;
|
|
||||||
|
|
||||||
const IPV4_MIN_HEADER_SIZE: usize = 20;
|
|
||||||
const IPV4_LEN_OFF: usize = 2;
|
|
||||||
const IPV4_SRC_IP_OFF: usize = 12;
|
|
||||||
const IPV4_DST_IP_OFF: usize = 16;
|
|
||||||
const IPV4_IP_SZ: usize = 4;
|
|
||||||
|
|
||||||
const IPV6_MIN_HEADER_SIZE: usize = 40;
|
|
||||||
const IPV6_LEN_OFF: usize = 4;
|
|
||||||
const IPV6_SRC_IP_OFF: usize = 8;
|
|
||||||
const IPV6_DST_IP_OFF: usize = 24;
|
|
||||||
const IPV6_IP_SZ: usize = 16;
|
|
||||||
|
|
||||||
const IP_LEN_SZ: usize = 2;
|
|
||||||
|
|
||||||
const MAX_QUEUE_DEPTH: usize = 256;
|
|
||||||
/// number of sessions in the ring, better keep a PoT
|
|
||||||
const N_SESSIONS: usize = 8;
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub enum TunnResult<'a> {
|
|
||||||
Done,
|
|
||||||
Err(WireGuardError),
|
|
||||||
WriteToNetwork(&'a mut [u8]),
|
|
||||||
WriteToTunnelV4(&'a mut [u8], Ipv4Addr),
|
|
||||||
WriteToTunnelV6(&'a mut [u8], Ipv6Addr),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'a> From<WireGuardError> for TunnResult<'a> {
|
|
||||||
fn from(err: WireGuardError) -> TunnResult<'a> {
|
|
||||||
TunnResult::Err(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Tunnel represents a point-to-point WireGuard connection
|
|
||||||
pub struct Tunn {
|
|
||||||
/// The handshake currently in progress
|
|
||||||
handshake: handshake::Handshake,
|
|
||||||
/// The N_SESSIONS most recent sessions, index is session id modulo N_SESSIONS
|
|
||||||
sessions: [Option<session::Session>; N_SESSIONS],
|
|
||||||
/// Index of most recently used session
|
|
||||||
current: usize,
|
|
||||||
/// Queue to store blocked packets
|
|
||||||
packet_queue: VecDeque<Vec<u8>>,
|
|
||||||
/// Keeps tabs on the expiring timers
|
|
||||||
timers: timers::Timers,
|
|
||||||
tx_bytes: usize,
|
|
||||||
rx_bytes: usize,
|
|
||||||
rate_limiter: Arc<RateLimiter>,
|
|
||||||
}
|
|
||||||
|
|
||||||
type MessageType = u32;
|
|
||||||
const HANDSHAKE_INIT: MessageType = 1;
|
|
||||||
const HANDSHAKE_RESP: MessageType = 2;
|
|
||||||
const COOKIE_REPLY: MessageType = 3;
|
|
||||||
const DATA: MessageType = 4;
|
|
||||||
|
|
||||||
const HANDSHAKE_INIT_SZ: usize = 148;
|
|
||||||
const HANDSHAKE_RESP_SZ: usize = 92;
|
|
||||||
const COOKIE_REPLY_SZ: usize = 64;
|
|
||||||
const DATA_OVERHEAD_SZ: usize = 32;
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct HandshakeInit<'a> {
|
|
||||||
sender_idx: u32,
|
|
||||||
unencrypted_ephemeral: &'a [u8; 32],
|
|
||||||
encrypted_static: &'a [u8],
|
|
||||||
encrypted_timestamp: &'a [u8],
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct HandshakeResponse<'a> {
|
|
||||||
sender_idx: u32,
|
|
||||||
pub receiver_idx: u32,
|
|
||||||
unencrypted_ephemeral: &'a [u8; 32],
|
|
||||||
encrypted_nothing: &'a [u8],
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct PacketCookieReply<'a> {
|
|
||||||
pub receiver_idx: u32,
|
|
||||||
nonce: &'a [u8],
|
|
||||||
encrypted_cookie: &'a [u8],
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct PacketData<'a> {
|
|
||||||
pub receiver_idx: u32,
|
|
||||||
counter: u64,
|
|
||||||
encrypted_encapsulated_packet: &'a [u8],
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Describes a packet from network
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub enum Packet<'a> {
|
|
||||||
HandshakeInit(HandshakeInit<'a>),
|
|
||||||
HandshakeResponse(HandshakeResponse<'a>),
|
|
||||||
PacketCookieReply(PacketCookieReply<'a>),
|
|
||||||
PacketData(PacketData<'a>),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Tunn {
|
|
||||||
#[inline(always)]
|
|
||||||
pub fn parse_incoming_packet(src: &[u8]) -> Result<Packet, WireGuardError> {
|
|
||||||
if src.len() < 4 {
|
|
||||||
return Err(WireGuardError::InvalidPacket);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Checks the type, as well as the reserved zero fields
|
|
||||||
let packet_type = u32::from_le_bytes(src[0..4].try_into().unwrap());
|
|
||||||
|
|
||||||
Ok(match (packet_type, src.len()) {
|
|
||||||
(HANDSHAKE_INIT, HANDSHAKE_INIT_SZ) => Packet::HandshakeInit(HandshakeInit {
|
|
||||||
sender_idx: u32::from_le_bytes(src[4..8].try_into().unwrap()),
|
|
||||||
unencrypted_ephemeral: <&[u8; 32] as TryFrom<&[u8]>>::try_from(&src[8..40])
|
|
||||||
.expect("length already checked above"),
|
|
||||||
encrypted_static: &src[40..88],
|
|
||||||
encrypted_timestamp: &src[88..116],
|
|
||||||
}),
|
|
||||||
(HANDSHAKE_RESP, HANDSHAKE_RESP_SZ) => Packet::HandshakeResponse(HandshakeResponse {
|
|
||||||
sender_idx: u32::from_le_bytes(src[4..8].try_into().unwrap()),
|
|
||||||
receiver_idx: u32::from_le_bytes(src[8..12].try_into().unwrap()),
|
|
||||||
unencrypted_ephemeral: <&[u8; 32] as TryFrom<&[u8]>>::try_from(&src[12..44])
|
|
||||||
.expect("length already checked above"),
|
|
||||||
encrypted_nothing: &src[44..60],
|
|
||||||
}),
|
|
||||||
(COOKIE_REPLY, COOKIE_REPLY_SZ) => Packet::PacketCookieReply(PacketCookieReply {
|
|
||||||
receiver_idx: u32::from_le_bytes(src[4..8].try_into().unwrap()),
|
|
||||||
nonce: &src[8..32],
|
|
||||||
encrypted_cookie: &src[32..64],
|
|
||||||
}),
|
|
||||||
(DATA, DATA_OVERHEAD_SZ..=std::usize::MAX) => Packet::PacketData(PacketData {
|
|
||||||
receiver_idx: u32::from_le_bytes(src[4..8].try_into().unwrap()),
|
|
||||||
counter: u64::from_le_bytes(src[8..16].try_into().unwrap()),
|
|
||||||
encrypted_encapsulated_packet: &src[16..],
|
|
||||||
}),
|
|
||||||
_ => return Err(WireGuardError::InvalidPacket),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn is_expired(&self) -> bool {
|
|
||||||
self.handshake.is_expired()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn dst_address(packet: &[u8]) -> Option<IpAddr> {
|
|
||||||
if packet.is_empty() {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
|
|
||||||
match packet[0] >> 4 {
|
|
||||||
4 if packet.len() >= IPV4_MIN_HEADER_SIZE => {
|
|
||||||
let addr_bytes: [u8; IPV4_IP_SZ] = packet
|
|
||||||
[IPV4_DST_IP_OFF..IPV4_DST_IP_OFF + IPV4_IP_SZ]
|
|
||||||
.try_into()
|
|
||||||
.unwrap();
|
|
||||||
Some(IpAddr::from(addr_bytes))
|
|
||||||
}
|
|
||||||
6 if packet.len() >= IPV6_MIN_HEADER_SIZE => {
|
|
||||||
let addr_bytes: [u8; IPV6_IP_SZ] = packet
|
|
||||||
[IPV6_DST_IP_OFF..IPV6_DST_IP_OFF + IPV6_IP_SZ]
|
|
||||||
.try_into()
|
|
||||||
.unwrap();
|
|
||||||
Some(IpAddr::from(addr_bytes))
|
|
||||||
}
|
|
||||||
_ => None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create a new tunnel using own private key and the peer public key
|
|
||||||
pub fn new(
|
|
||||||
static_private: x25519::StaticSecret,
|
|
||||||
peer_static_public: x25519::PublicKey,
|
|
||||||
preshared_key: Option<[u8; 32]>,
|
|
||||||
persistent_keepalive: Option<u16>,
|
|
||||||
index: u32,
|
|
||||||
rate_limiter: Option<Arc<RateLimiter>>,
|
|
||||||
) -> Result<Self, &'static str> {
|
|
||||||
let static_public = x25519::PublicKey::from(&static_private);
|
|
||||||
|
|
||||||
let tunn = Tunn {
|
|
||||||
handshake: Handshake::new(
|
|
||||||
static_private,
|
|
||||||
static_public,
|
|
||||||
peer_static_public,
|
|
||||||
index << 8,
|
|
||||||
preshared_key,
|
|
||||||
)
|
|
||||||
.map_err(|_| "Invalid parameters")?,
|
|
||||||
sessions: Default::default(),
|
|
||||||
current: Default::default(),
|
|
||||||
tx_bytes: Default::default(),
|
|
||||||
rx_bytes: Default::default(),
|
|
||||||
|
|
||||||
packet_queue: VecDeque::new(),
|
|
||||||
timers: Timers::new(persistent_keepalive, rate_limiter.is_none()),
|
|
||||||
|
|
||||||
rate_limiter: rate_limiter.unwrap_or_else(|| {
|
|
||||||
Arc::new(RateLimiter::new(&static_public, PEER_HANDSHAKE_RATE_LIMIT))
|
|
||||||
}),
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(tunn)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Update the private key and clear existing sessions
|
|
||||||
pub fn set_static_private(
|
|
||||||
&mut self,
|
|
||||||
static_private: x25519::StaticSecret,
|
|
||||||
static_public: x25519::PublicKey,
|
|
||||||
rate_limiter: Option<Arc<RateLimiter>>,
|
|
||||||
) -> Result<(), WireGuardError> {
|
|
||||||
self.timers.should_reset_rr = rate_limiter.is_none();
|
|
||||||
self.rate_limiter = rate_limiter.unwrap_or_else(|| {
|
|
||||||
Arc::new(RateLimiter::new(&static_public, PEER_HANDSHAKE_RATE_LIMIT))
|
|
||||||
});
|
|
||||||
self.handshake
|
|
||||||
.set_static_private(static_private, static_public)?;
|
|
||||||
for s in &mut self.sessions {
|
|
||||||
*s = None;
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Encapsulate a single packet from the tunnel interface.
|
|
||||||
/// Returns TunnResult.
|
|
||||||
///
|
|
||||||
/// # Panics
|
|
||||||
/// Panics if dst buffer is too small.
|
|
||||||
/// Size of dst should be at least src.len() + 32, and no less than 148 bytes.
|
|
||||||
pub fn encapsulate<'a>(&mut self, src: &[u8], dst: &'a mut [u8]) -> TunnResult<'a> {
|
|
||||||
let current = self.current;
|
|
||||||
if let Some(ref session) = self.sessions[current % N_SESSIONS] {
|
|
||||||
// Send the packet using an established session
|
|
||||||
let packet = session.format_packet_data(src, dst);
|
|
||||||
self.timer_tick(TimerName::TimeLastPacketSent);
|
|
||||||
// Exclude Keepalive packets from timer update.
|
|
||||||
if !src.is_empty() {
|
|
||||||
self.timer_tick(TimerName::TimeLastDataPacketSent);
|
|
||||||
}
|
|
||||||
self.tx_bytes += src.len();
|
|
||||||
return TunnResult::WriteToNetwork(packet);
|
|
||||||
}
|
|
||||||
|
|
||||||
// If there is no session, queue the packet for future retry
|
|
||||||
self.queue_packet(src);
|
|
||||||
// Initiate a new handshake if none is in progress
|
|
||||||
self.format_handshake_initiation(dst, false)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Receives a UDP datagram from the network and parses it.
|
|
||||||
/// Returns TunnResult.
|
|
||||||
///
|
|
||||||
/// If the result is of type TunnResult::WriteToNetwork, should repeat the call with empty datagram,
|
|
||||||
/// until TunnResult::Done is returned. If batch processing packets, it is OK to defer until last
|
|
||||||
/// packet is processed.
|
|
||||||
pub fn decapsulate<'a>(
|
|
||||||
&mut self,
|
|
||||||
src_addr: Option<IpAddr>,
|
|
||||||
datagram: &[u8],
|
|
||||||
dst: &'a mut [u8],
|
|
||||||
) -> TunnResult<'a> {
|
|
||||||
if datagram.is_empty() {
|
|
||||||
// Indicates a repeated call
|
|
||||||
return self.send_queued_packet(dst);
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut cookie = [0u8; COOKIE_REPLY_SZ];
|
|
||||||
let packet = match self
|
|
||||||
.rate_limiter
|
|
||||||
.verify_packet(src_addr, datagram, &mut cookie)
|
|
||||||
{
|
|
||||||
Ok(packet) => packet,
|
|
||||||
Err(TunnResult::WriteToNetwork(cookie)) => {
|
|
||||||
dst[..cookie.len()].copy_from_slice(cookie);
|
|
||||||
return TunnResult::WriteToNetwork(&mut dst[..cookie.len()]);
|
|
||||||
}
|
|
||||||
Err(TunnResult::Err(e)) => return TunnResult::Err(e),
|
|
||||||
_ => unreachable!(),
|
|
||||||
};
|
|
||||||
|
|
||||||
self.handle_verified_packet(packet, dst)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn handle_verified_packet<'a>(
|
|
||||||
&mut self,
|
|
||||||
packet: Packet,
|
|
||||||
dst: &'a mut [u8],
|
|
||||||
) -> TunnResult<'a> {
|
|
||||||
match packet {
|
|
||||||
Packet::HandshakeInit(p) => self.handle_handshake_init(p, dst),
|
|
||||||
Packet::HandshakeResponse(p) => self.handle_handshake_response(p, dst),
|
|
||||||
Packet::PacketCookieReply(p) => self.handle_cookie_reply(p),
|
|
||||||
Packet::PacketData(p) => self.handle_data(p, dst),
|
|
||||||
}
|
|
||||||
.unwrap_or_else(TunnResult::from)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn handle_handshake_init<'a>(
|
|
||||||
&mut self,
|
|
||||||
p: HandshakeInit,
|
|
||||||
dst: &'a mut [u8],
|
|
||||||
) -> Result<TunnResult<'a>, WireGuardError> {
|
|
||||||
tracing::debug!(
|
|
||||||
message = "Received handshake_initiation",
|
|
||||||
remote_idx = p.sender_idx
|
|
||||||
);
|
|
||||||
|
|
||||||
let (packet, session) = self.handshake.receive_handshake_initialization(p, dst)?;
|
|
||||||
|
|
||||||
// Store new session in ring buffer
|
|
||||||
let index = session.local_index();
|
|
||||||
self.sessions[index % N_SESSIONS] = Some(session);
|
|
||||||
|
|
||||||
self.timer_tick(TimerName::TimeLastPacketReceived);
|
|
||||||
self.timer_tick(TimerName::TimeLastPacketSent);
|
|
||||||
self.timer_tick_session_established(false, index); // New session established, we are not the initiator
|
|
||||||
|
|
||||||
tracing::debug!(message = "Sending handshake_response", local_idx = index);
|
|
||||||
|
|
||||||
Ok(TunnResult::WriteToNetwork(packet))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn handle_handshake_response<'a>(
|
|
||||||
&mut self,
|
|
||||||
p: HandshakeResponse,
|
|
||||||
dst: &'a mut [u8],
|
|
||||||
) -> Result<TunnResult<'a>, WireGuardError> {
|
|
||||||
tracing::debug!(
|
|
||||||
message = "Received handshake_response",
|
|
||||||
local_idx = p.receiver_idx,
|
|
||||||
remote_idx = p.sender_idx
|
|
||||||
);
|
|
||||||
|
|
||||||
let session = self.handshake.receive_handshake_response(p)?;
|
|
||||||
|
|
||||||
let keepalive_packet = session.format_packet_data(&[], dst);
|
|
||||||
// Store new session in ring buffer
|
|
||||||
let l_idx = session.local_index();
|
|
||||||
let index = l_idx % N_SESSIONS;
|
|
||||||
self.sessions[index] = Some(session);
|
|
||||||
|
|
||||||
self.timer_tick(TimerName::TimeLastPacketReceived);
|
|
||||||
self.timer_tick_session_established(true, index); // New session established, we are the initiator
|
|
||||||
self.set_current_session(l_idx);
|
|
||||||
|
|
||||||
tracing::debug!("Sending keepalive");
|
|
||||||
|
|
||||||
Ok(TunnResult::WriteToNetwork(keepalive_packet)) // Send a keepalive as a response
|
|
||||||
}
|
|
||||||
|
|
||||||
fn handle_cookie_reply<'a>(
|
|
||||||
&mut self,
|
|
||||||
p: PacketCookieReply,
|
|
||||||
) -> Result<TunnResult<'a>, WireGuardError> {
|
|
||||||
tracing::debug!(
|
|
||||||
message = "Received cookie_reply",
|
|
||||||
local_idx = p.receiver_idx
|
|
||||||
);
|
|
||||||
|
|
||||||
self.handshake.receive_cookie_reply(p)?;
|
|
||||||
self.timer_tick(TimerName::TimeLastPacketReceived);
|
|
||||||
self.timer_tick(TimerName::TimeCookieReceived);
|
|
||||||
|
|
||||||
tracing::debug!("Did set cookie");
|
|
||||||
|
|
||||||
Ok(TunnResult::Done)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Update the index of the currently used session, if needed
|
|
||||||
fn set_current_session(&mut self, new_idx: usize) {
|
|
||||||
let cur_idx = self.current;
|
|
||||||
if cur_idx == new_idx {
|
|
||||||
// There is nothing to do, already using this session, this is the common case
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if self.sessions[cur_idx % N_SESSIONS].is_none()
|
|
||||||
|| self.timers.session_timers[new_idx % N_SESSIONS]
|
|
||||||
>= self.timers.session_timers[cur_idx % N_SESSIONS]
|
|
||||||
{
|
|
||||||
self.current = new_idx;
|
|
||||||
tracing::debug!(message = "New session", session = new_idx);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Decrypts a data packet, and stores the decapsulated packet in dst.
|
|
||||||
fn handle_data<'a>(
|
|
||||||
&mut self,
|
|
||||||
packet: PacketData,
|
|
||||||
dst: &'a mut [u8],
|
|
||||||
) -> Result<TunnResult<'a>, WireGuardError> {
|
|
||||||
let r_idx = packet.receiver_idx as usize;
|
|
||||||
let idx = r_idx % N_SESSIONS;
|
|
||||||
|
|
||||||
// Get the (probably) right session
|
|
||||||
let decapsulated_packet = {
|
|
||||||
let session = self.sessions[idx].as_ref();
|
|
||||||
let session = session.ok_or_else(|| {
|
|
||||||
tracing::trace!(message = "No current session available", remote_idx = r_idx);
|
|
||||||
WireGuardError::NoCurrentSession
|
|
||||||
})?;
|
|
||||||
session.receive_packet_data(packet, dst)?
|
|
||||||
};
|
|
||||||
|
|
||||||
self.set_current_session(r_idx);
|
|
||||||
|
|
||||||
self.timer_tick(TimerName::TimeLastPacketReceived);
|
|
||||||
|
|
||||||
Ok(self.validate_decapsulated_packet(decapsulated_packet))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Formats a new handshake initiation message and store it in dst. If force_resend is true will send
|
|
||||||
/// a new handshake, even if a handshake is already in progress (for example when a handshake times out)
|
|
||||||
pub fn format_handshake_initiation<'a>(
|
|
||||||
&mut self,
|
|
||||||
dst: &'a mut [u8],
|
|
||||||
force_resend: bool,
|
|
||||||
) -> TunnResult<'a> {
|
|
||||||
if self.handshake.is_in_progress() && !force_resend {
|
|
||||||
return TunnResult::Done;
|
|
||||||
}
|
|
||||||
|
|
||||||
if self.handshake.is_expired() {
|
|
||||||
self.timers.clear();
|
|
||||||
}
|
|
||||||
|
|
||||||
let starting_new_handshake = !self.handshake.is_in_progress();
|
|
||||||
|
|
||||||
match self.handshake.format_handshake_initiation(dst) {
|
|
||||||
Ok(packet) => {
|
|
||||||
tracing::debug!("Sending handshake_initiation");
|
|
||||||
|
|
||||||
if starting_new_handshake {
|
|
||||||
self.timer_tick(TimerName::TimeLastHandshakeStarted);
|
|
||||||
}
|
|
||||||
self.timer_tick(TimerName::TimeLastPacketSent);
|
|
||||||
TunnResult::WriteToNetwork(packet)
|
|
||||||
}
|
|
||||||
Err(e) => TunnResult::Err(e),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Check if an IP packet is v4 or v6, truncate to the length indicated by the length field
|
|
||||||
/// Returns the truncated packet and the source IP as TunnResult
|
|
||||||
fn validate_decapsulated_packet<'a>(&mut self, packet: &'a mut [u8]) -> TunnResult<'a> {
|
|
||||||
let (computed_len, src_ip_address) = match packet.len() {
|
|
||||||
0 => return TunnResult::Done, // This is keepalive, and not an error
|
|
||||||
_ if packet[0] >> 4 == 4 && packet.len() >= IPV4_MIN_HEADER_SIZE => {
|
|
||||||
let len_bytes: [u8; IP_LEN_SZ] = packet[IPV4_LEN_OFF..IPV4_LEN_OFF + IP_LEN_SZ]
|
|
||||||
.try_into()
|
|
||||||
.unwrap();
|
|
||||||
let addr_bytes: [u8; IPV4_IP_SZ] = packet
|
|
||||||
[IPV4_SRC_IP_OFF..IPV4_SRC_IP_OFF + IPV4_IP_SZ]
|
|
||||||
.try_into()
|
|
||||||
.unwrap();
|
|
||||||
(
|
|
||||||
u16::from_be_bytes(len_bytes) as usize,
|
|
||||||
IpAddr::from(addr_bytes),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
_ if packet[0] >> 4 == 6 && packet.len() >= IPV6_MIN_HEADER_SIZE => {
|
|
||||||
let len_bytes: [u8; IP_LEN_SZ] = packet[IPV6_LEN_OFF..IPV6_LEN_OFF + IP_LEN_SZ]
|
|
||||||
.try_into()
|
|
||||||
.unwrap();
|
|
||||||
let addr_bytes: [u8; IPV6_IP_SZ] = packet
|
|
||||||
[IPV6_SRC_IP_OFF..IPV6_SRC_IP_OFF + IPV6_IP_SZ]
|
|
||||||
.try_into()
|
|
||||||
.unwrap();
|
|
||||||
(
|
|
||||||
u16::from_be_bytes(len_bytes) as usize + IPV6_MIN_HEADER_SIZE,
|
|
||||||
IpAddr::from(addr_bytes),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
_ => return TunnResult::Err(WireGuardError::InvalidPacket),
|
|
||||||
};
|
|
||||||
|
|
||||||
if computed_len > packet.len() {
|
|
||||||
return TunnResult::Err(WireGuardError::InvalidPacket);
|
|
||||||
}
|
|
||||||
|
|
||||||
self.timer_tick(TimerName::TimeLastDataPacketReceived);
|
|
||||||
self.rx_bytes += computed_len;
|
|
||||||
|
|
||||||
match src_ip_address {
|
|
||||||
IpAddr::V4(addr) => TunnResult::WriteToTunnelV4(&mut packet[..computed_len], addr),
|
|
||||||
IpAddr::V6(addr) => TunnResult::WriteToTunnelV6(&mut packet[..computed_len], addr),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get a packet from the queue, and try to encapsulate it
|
|
||||||
fn send_queued_packet<'a>(&mut self, dst: &'a mut [u8]) -> TunnResult<'a> {
|
|
||||||
if let Some(packet) = self.dequeue_packet() {
|
|
||||||
match self.encapsulate(&packet, dst) {
|
|
||||||
TunnResult::Err(_) => {
|
|
||||||
// On error, return packet to the queue
|
|
||||||
self.requeue_packet(packet);
|
|
||||||
}
|
|
||||||
r => return r,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
TunnResult::Done
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Push packet to the back of the queue
|
|
||||||
fn queue_packet(&mut self, packet: &[u8]) {
|
|
||||||
if self.packet_queue.len() < MAX_QUEUE_DEPTH {
|
|
||||||
// Drop if too many are already in queue
|
|
||||||
self.packet_queue.push_back(packet.to_vec());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Push packet to the front of the queue
|
|
||||||
fn requeue_packet(&mut self, packet: Vec<u8>) {
|
|
||||||
if self.packet_queue.len() < MAX_QUEUE_DEPTH {
|
|
||||||
// Drop if too many are already in queue
|
|
||||||
self.packet_queue.push_front(packet);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn dequeue_packet(&mut self) -> Option<Vec<u8>> {
|
|
||||||
self.packet_queue.pop_front()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn estimate_loss(&self) -> f32 {
|
|
||||||
let session_idx = self.current;
|
|
||||||
|
|
||||||
let mut weight = 9.0;
|
|
||||||
let mut cur_avg = 0.0;
|
|
||||||
let mut total_weight = 0.0;
|
|
||||||
|
|
||||||
for i in 0..N_SESSIONS {
|
|
||||||
if let Some(ref session) = self.sessions[(session_idx.wrapping_sub(i)) % N_SESSIONS] {
|
|
||||||
let (expected, received) = session.current_packet_cnt();
|
|
||||||
|
|
||||||
let loss = if expected == 0 {
|
|
||||||
0.0
|
|
||||||
} else {
|
|
||||||
1.0 - received as f32 / expected as f32
|
|
||||||
};
|
|
||||||
|
|
||||||
cur_avg += loss * weight;
|
|
||||||
total_weight += weight;
|
|
||||||
weight /= 3.0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if total_weight == 0.0 {
|
|
||||||
0.0
|
|
||||||
} else {
|
|
||||||
cur_avg / total_weight
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Return stats from the tunnel:
|
|
||||||
/// * Time since last handshake in seconds
|
|
||||||
/// * Data bytes sent
|
|
||||||
/// * Data bytes received
|
|
||||||
pub fn stats(&self) -> (Option<Duration>, usize, usize, f32, Option<u32>) {
|
|
||||||
let time = self.time_since_last_handshake();
|
|
||||||
let tx_bytes = self.tx_bytes;
|
|
||||||
let rx_bytes = self.rx_bytes;
|
|
||||||
let loss = self.estimate_loss();
|
|
||||||
let rtt = self.handshake.last_rtt;
|
|
||||||
|
|
||||||
(time, tx_bytes, rx_bytes, loss, rtt)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
#[cfg(feature = "mock-instant")]
|
|
||||||
use crate::noise::timers::{REKEY_AFTER_TIME, REKEY_TIMEOUT};
|
|
||||||
|
|
||||||
use super::*;
|
|
||||||
use rand_core::{OsRng, RngCore};
|
|
||||||
|
|
||||||
fn create_two_tuns() -> (Tunn, Tunn) {
|
|
||||||
let my_secret_key = x25519_dalek::StaticSecret::random_from_rng(OsRng);
|
|
||||||
let my_public_key = x25519_dalek::PublicKey::from(&my_secret_key);
|
|
||||||
let my_idx = OsRng.next_u32();
|
|
||||||
|
|
||||||
let their_secret_key = x25519_dalek::StaticSecret::random_from_rng(OsRng);
|
|
||||||
let their_public_key = x25519_dalek::PublicKey::from(&their_secret_key);
|
|
||||||
let their_idx = OsRng.next_u32();
|
|
||||||
|
|
||||||
let my_tun = Tunn::new(my_secret_key, their_public_key, None, None, my_idx, None).unwrap();
|
|
||||||
|
|
||||||
let their_tun =
|
|
||||||
Tunn::new(their_secret_key, my_public_key, None, None, their_idx, None).unwrap();
|
|
||||||
|
|
||||||
(my_tun, their_tun)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn create_handshake_init(tun: &mut Tunn) -> Vec<u8> {
|
|
||||||
let mut dst = vec![0u8; 2048];
|
|
||||||
let handshake_init = tun.format_handshake_initiation(&mut dst, false);
|
|
||||||
assert!(matches!(handshake_init, TunnResult::WriteToNetwork(_)));
|
|
||||||
let handshake_init = if let TunnResult::WriteToNetwork(sent) = handshake_init {
|
|
||||||
sent
|
|
||||||
} else {
|
|
||||||
unreachable!();
|
|
||||||
};
|
|
||||||
|
|
||||||
handshake_init.into()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn create_handshake_response(tun: &mut Tunn, handshake_init: &[u8]) -> Vec<u8> {
|
|
||||||
let mut dst = vec![0u8; 2048];
|
|
||||||
let handshake_resp = tun.decapsulate(None, handshake_init, &mut dst);
|
|
||||||
assert!(matches!(handshake_resp, TunnResult::WriteToNetwork(_)));
|
|
||||||
|
|
||||||
let handshake_resp = if let TunnResult::WriteToNetwork(sent) = handshake_resp {
|
|
||||||
sent
|
|
||||||
} else {
|
|
||||||
unreachable!();
|
|
||||||
};
|
|
||||||
|
|
||||||
handshake_resp.into()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn parse_handshake_resp(tun: &mut Tunn, handshake_resp: &[u8]) -> Vec<u8> {
|
|
||||||
let mut dst = vec![0u8; 2048];
|
|
||||||
let keepalive = tun.decapsulate(None, handshake_resp, &mut dst);
|
|
||||||
assert!(matches!(keepalive, TunnResult::WriteToNetwork(_)));
|
|
||||||
|
|
||||||
let keepalive = if let TunnResult::WriteToNetwork(sent) = keepalive {
|
|
||||||
sent
|
|
||||||
} else {
|
|
||||||
unreachable!();
|
|
||||||
};
|
|
||||||
|
|
||||||
keepalive.into()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn parse_keepalive(tun: &mut Tunn, keepalive: &[u8]) {
|
|
||||||
let mut dst = vec![0u8; 2048];
|
|
||||||
let keepalive = tun.decapsulate(None, keepalive, &mut dst);
|
|
||||||
assert!(matches!(keepalive, TunnResult::Done));
|
|
||||||
}
|
|
||||||
|
|
||||||
fn create_two_tuns_and_handshake() -> (Tunn, Tunn) {
|
|
||||||
let (mut my_tun, mut their_tun) = create_two_tuns();
|
|
||||||
let init = create_handshake_init(&mut my_tun);
|
|
||||||
let resp = create_handshake_response(&mut their_tun, &init);
|
|
||||||
let keepalive = parse_handshake_resp(&mut my_tun, &resp);
|
|
||||||
parse_keepalive(&mut their_tun, &keepalive);
|
|
||||||
|
|
||||||
(my_tun, their_tun)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn create_ipv4_udp_packet() -> Vec<u8> {
|
|
||||||
let header =
|
|
||||||
etherparse::PacketBuilder::ipv4([192, 168, 1, 2], [192, 168, 1, 3], 5).udp(5678, 23);
|
|
||||||
let payload = [0, 1, 2, 3];
|
|
||||||
let mut packet = Vec::<u8>::with_capacity(header.size(payload.len()));
|
|
||||||
header.write(&mut packet, &payload).unwrap();
|
|
||||||
packet
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(feature = "mock-instant")]
|
|
||||||
fn update_timer_results_in_handshake(tun: &mut Tunn) {
|
|
||||||
let mut dst = vec![0u8; 2048];
|
|
||||||
let result = tun.update_timers(&mut dst);
|
|
||||||
assert!(matches!(result, TunnResult::WriteToNetwork(_)));
|
|
||||||
let packet_data = if let TunnResult::WriteToNetwork(data) = result {
|
|
||||||
data
|
|
||||||
} else {
|
|
||||||
unreachable!();
|
|
||||||
};
|
|
||||||
let packet = Tunn::parse_incoming_packet(packet_data).unwrap();
|
|
||||||
assert!(matches!(packet, Packet::HandshakeInit(_)));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn create_two_tunnels_linked_to_eachother() {
|
|
||||||
let (_my_tun, _their_tun) = create_two_tuns();
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn handshake_init() {
|
|
||||||
let (mut my_tun, _their_tun) = create_two_tuns();
|
|
||||||
let init = create_handshake_init(&mut my_tun);
|
|
||||||
let packet = Tunn::parse_incoming_packet(&init).unwrap();
|
|
||||||
assert!(matches!(packet, Packet::HandshakeInit(_)));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn handshake_init_and_response() {
|
|
||||||
let (mut my_tun, mut their_tun) = create_two_tuns();
|
|
||||||
let init = create_handshake_init(&mut my_tun);
|
|
||||||
let resp = create_handshake_response(&mut their_tun, &init);
|
|
||||||
let packet = Tunn::parse_incoming_packet(&resp).unwrap();
|
|
||||||
assert!(matches!(packet, Packet::HandshakeResponse(_)));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn full_handshake() {
|
|
||||||
let (mut my_tun, mut their_tun) = create_two_tuns();
|
|
||||||
let init = create_handshake_init(&mut my_tun);
|
|
||||||
let resp = create_handshake_response(&mut their_tun, &init);
|
|
||||||
let keepalive = parse_handshake_resp(&mut my_tun, &resp);
|
|
||||||
let packet = Tunn::parse_incoming_packet(&keepalive).unwrap();
|
|
||||||
assert!(matches!(packet, Packet::PacketData(_)));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn full_handshake_plus_timers() {
|
|
||||||
let (mut my_tun, mut their_tun) = create_two_tuns_and_handshake();
|
|
||||||
// Time has not yet advanced so their is nothing to do
|
|
||||||
assert!(matches!(my_tun.update_timers(&mut []), TunnResult::Done));
|
|
||||||
assert!(matches!(their_tun.update_timers(&mut []), TunnResult::Done));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
#[cfg(feature = "mock-instant")]
|
|
||||||
fn new_handshake_after_two_mins() {
|
|
||||||
let (mut my_tun, mut their_tun) = create_two_tuns_and_handshake();
|
|
||||||
let mut my_dst = [0u8; 1024];
|
|
||||||
|
|
||||||
// Advance time 1 second and "send" 1 packet so that we send a handshake
|
|
||||||
// after the timeout
|
|
||||||
mock_instant::MockClock::advance(Duration::from_secs(1));
|
|
||||||
assert!(matches!(their_tun.update_timers(&mut []), TunnResult::Done));
|
|
||||||
assert!(matches!(
|
|
||||||
my_tun.update_timers(&mut my_dst),
|
|
||||||
TunnResult::Done
|
|
||||||
));
|
|
||||||
let sent_packet_buf = create_ipv4_udp_packet();
|
|
||||||
let data = my_tun.encapsulate(&sent_packet_buf, &mut my_dst);
|
|
||||||
assert!(matches!(data, TunnResult::WriteToNetwork(_)));
|
|
||||||
|
|
||||||
//Advance to timeout
|
|
||||||
mock_instant::MockClock::advance(REKEY_AFTER_TIME);
|
|
||||||
assert!(matches!(their_tun.update_timers(&mut []), TunnResult::Done));
|
|
||||||
update_timer_results_in_handshake(&mut my_tun);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
#[cfg(feature = "mock-instant")]
|
|
||||||
fn handshake_no_resp_rekey_timeout() {
|
|
||||||
let (mut my_tun, _their_tun) = create_two_tuns();
|
|
||||||
|
|
||||||
let init = create_handshake_init(&mut my_tun);
|
|
||||||
let packet = Tunn::parse_incoming_packet(&init).unwrap();
|
|
||||||
assert!(matches!(packet, Packet::HandshakeInit(_)));
|
|
||||||
|
|
||||||
mock_instant::MockClock::advance(REKEY_TIMEOUT);
|
|
||||||
update_timer_results_in_handshake(&mut my_tun)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn one_ip_packet() {
|
|
||||||
let (mut my_tun, mut their_tun) = create_two_tuns_and_handshake();
|
|
||||||
let mut my_dst = [0u8; 1024];
|
|
||||||
let mut their_dst = [0u8; 1024];
|
|
||||||
|
|
||||||
let sent_packet_buf = create_ipv4_udp_packet();
|
|
||||||
|
|
||||||
let data = my_tun.encapsulate(&sent_packet_buf, &mut my_dst);
|
|
||||||
assert!(matches!(data, TunnResult::WriteToNetwork(_)));
|
|
||||||
let data = if let TunnResult::WriteToNetwork(sent) = data {
|
|
||||||
sent
|
|
||||||
} else {
|
|
||||||
unreachable!();
|
|
||||||
};
|
|
||||||
|
|
||||||
let data = their_tun.decapsulate(None, data, &mut their_dst);
|
|
||||||
assert!(matches!(data, TunnResult::WriteToTunnelV4(..)));
|
|
||||||
let recv_packet_buf = if let TunnResult::WriteToTunnelV4(recv, _addr) = data {
|
|
||||||
recv
|
|
||||||
} else {
|
|
||||||
unreachable!();
|
|
||||||
};
|
|
||||||
assert_eq!(sent_packet_buf, recv_packet_buf);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,193 +0,0 @@
|
||||||
use super::handshake::{b2s_hash, b2s_keyed_mac_16, b2s_keyed_mac_16_2, b2s_mac_24};
|
|
||||||
use crate::noise::handshake::{LABEL_COOKIE, LABEL_MAC1};
|
|
||||||
use crate::noise::{HandshakeInit, HandshakeResponse, Packet, Tunn, TunnResult, WireGuardError};
|
|
||||||
|
|
||||||
#[cfg(feature = "mock-instant")]
|
|
||||||
use mock_instant::Instant;
|
|
||||||
use std::net::IpAddr;
|
|
||||||
use std::sync::atomic::{AtomicU64, Ordering};
|
|
||||||
|
|
||||||
#[cfg(not(feature = "mock-instant"))]
|
|
||||||
use crate::sleepyinstant::Instant;
|
|
||||||
|
|
||||||
use aead::generic_array::GenericArray;
|
|
||||||
use aead::{AeadInPlace, KeyInit};
|
|
||||||
use chacha20poly1305::{Key, XChaCha20Poly1305};
|
|
||||||
use parking_lot::Mutex;
|
|
||||||
use rand_core::{OsRng, RngCore};
|
|
||||||
use ring::constant_time::verify_slices_are_equal;
|
|
||||||
|
|
||||||
const COOKIE_REFRESH: u64 = 128; // Use 128 and not 120 so the compiler can optimize out the division
|
|
||||||
const COOKIE_SIZE: usize = 16;
|
|
||||||
const COOKIE_NONCE_SIZE: usize = 24;
|
|
||||||
|
|
||||||
/// How often should reset count in seconds
|
|
||||||
const RESET_PERIOD: u64 = 1;
|
|
||||||
|
|
||||||
type Cookie = [u8; COOKIE_SIZE];
|
|
||||||
|
|
||||||
/// There are two places where WireGuard requires "randomness" for cookies
|
|
||||||
/// * The 24 byte nonce in the cookie massage - here the only goal is to avoid nonce reuse
|
|
||||||
/// * A secret value that changes every two minutes
|
|
||||||
/// Because the main goal of the cookie is simply for a party to prove ownership of an IP address
|
|
||||||
/// we can relax the randomness definition a bit, in order to avoid locking, because using less
|
|
||||||
/// resources is the main goal of any DoS prevention mechanism.
|
|
||||||
/// In order to avoid locking and calls to rand we derive pseudo random values using the AEAD and
|
|
||||||
/// some counters.
|
|
||||||
pub struct RateLimiter {
|
|
||||||
/// The key we use to derive the nonce
|
|
||||||
nonce_key: [u8; 32],
|
|
||||||
/// The key we use to derive the cookie
|
|
||||||
secret_key: [u8; 16],
|
|
||||||
start_time: Instant,
|
|
||||||
/// A single 64 bit counter (should suffice for many years)
|
|
||||||
nonce_ctr: AtomicU64,
|
|
||||||
mac1_key: [u8; 32],
|
|
||||||
cookie_key: Key,
|
|
||||||
limit: u64,
|
|
||||||
/// The counter since last reset
|
|
||||||
count: AtomicU64,
|
|
||||||
/// The time last reset was performed on this rate limiter
|
|
||||||
last_reset: Mutex<Instant>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl RateLimiter {
|
|
||||||
pub fn new(public_key: &crate::x25519::PublicKey, limit: u64) -> Self {
|
|
||||||
let mut secret_key = [0u8; 16];
|
|
||||||
OsRng.fill_bytes(&mut secret_key);
|
|
||||||
RateLimiter {
|
|
||||||
nonce_key: Self::rand_bytes(),
|
|
||||||
secret_key,
|
|
||||||
start_time: Instant::now(),
|
|
||||||
nonce_ctr: AtomicU64::new(0),
|
|
||||||
mac1_key: b2s_hash(LABEL_MAC1, public_key.as_bytes()),
|
|
||||||
cookie_key: b2s_hash(LABEL_COOKIE, public_key.as_bytes()).into(),
|
|
||||||
limit,
|
|
||||||
count: AtomicU64::new(0),
|
|
||||||
last_reset: Mutex::new(Instant::now()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn rand_bytes() -> [u8; 32] {
|
|
||||||
let mut key = [0u8; 32];
|
|
||||||
OsRng.fill_bytes(&mut key);
|
|
||||||
key
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Reset packet count (ideally should be called with a period of 1 second)
|
|
||||||
pub fn reset_count(&self) {
|
|
||||||
// The rate limiter is not very accurate, but at the scale we care about it doesn't matter much
|
|
||||||
let current_time = Instant::now();
|
|
||||||
let mut last_reset_time = self.last_reset.lock();
|
|
||||||
if current_time.duration_since(*last_reset_time).as_secs() >= RESET_PERIOD {
|
|
||||||
self.count.store(0, Ordering::SeqCst);
|
|
||||||
*last_reset_time = current_time;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Compute the correct cookie value based on the current secret value and the source IP
|
|
||||||
fn current_cookie(&self, addr: IpAddr) -> Cookie {
|
|
||||||
let mut addr_bytes = [0u8; 16];
|
|
||||||
|
|
||||||
match addr {
|
|
||||||
IpAddr::V4(a) => addr_bytes[..4].copy_from_slice(&a.octets()[..]),
|
|
||||||
IpAddr::V6(a) => addr_bytes[..].copy_from_slice(&a.octets()[..]),
|
|
||||||
}
|
|
||||||
|
|
||||||
// The current cookie for a given IP is the MAC(responder.changing_secret_every_two_minutes, initiator.ip_address)
|
|
||||||
// First we derive the secret from the current time, the value of cur_counter would change with time.
|
|
||||||
let cur_counter = Instant::now().duration_since(self.start_time).as_secs() / COOKIE_REFRESH;
|
|
||||||
|
|
||||||
// Next we derive the cookie
|
|
||||||
b2s_keyed_mac_16_2(&self.secret_key, &cur_counter.to_le_bytes(), &addr_bytes)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn nonce(&self) -> [u8; COOKIE_NONCE_SIZE] {
|
|
||||||
let ctr = self.nonce_ctr.fetch_add(1, Ordering::Relaxed);
|
|
||||||
|
|
||||||
b2s_mac_24(&self.nonce_key, &ctr.to_le_bytes())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn is_under_load(&self) -> bool {
|
|
||||||
self.count.fetch_add(1, Ordering::SeqCst) >= self.limit
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn format_cookie_reply<'a>(
|
|
||||||
&self,
|
|
||||||
idx: u32,
|
|
||||||
cookie: Cookie,
|
|
||||||
mac1: &[u8],
|
|
||||||
dst: &'a mut [u8],
|
|
||||||
) -> Result<&'a mut [u8], WireGuardError> {
|
|
||||||
if dst.len() < super::COOKIE_REPLY_SZ {
|
|
||||||
return Err(WireGuardError::DestinationBufferTooSmall);
|
|
||||||
}
|
|
||||||
|
|
||||||
let (message_type, rest) = dst.split_at_mut(4);
|
|
||||||
let (receiver_index, rest) = rest.split_at_mut(4);
|
|
||||||
let (nonce, rest) = rest.split_at_mut(24);
|
|
||||||
let (encrypted_cookie, _) = rest.split_at_mut(16 + 16);
|
|
||||||
|
|
||||||
// msg.message_type = 3
|
|
||||||
// msg.reserved_zero = { 0, 0, 0 }
|
|
||||||
message_type.copy_from_slice(&super::COOKIE_REPLY.to_le_bytes());
|
|
||||||
// msg.receiver_index = little_endian(initiator.sender_index)
|
|
||||||
receiver_index.copy_from_slice(&idx.to_le_bytes());
|
|
||||||
nonce.copy_from_slice(&self.nonce()[..]);
|
|
||||||
|
|
||||||
let cipher = XChaCha20Poly1305::new(&self.cookie_key);
|
|
||||||
|
|
||||||
let iv = GenericArray::from_slice(nonce);
|
|
||||||
|
|
||||||
encrypted_cookie[..16].copy_from_slice(&cookie);
|
|
||||||
let tag = cipher
|
|
||||||
.encrypt_in_place_detached(iv, mac1, &mut encrypted_cookie[..16])
|
|
||||||
.map_err(|_| WireGuardError::DestinationBufferTooSmall)?;
|
|
||||||
|
|
||||||
encrypted_cookie[16..].copy_from_slice(&tag);
|
|
||||||
|
|
||||||
Ok(&mut dst[..super::COOKIE_REPLY_SZ])
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Verify the MAC fields on the datagram, and apply rate limiting if needed
|
|
||||||
pub fn verify_packet<'a, 'b>(
|
|
||||||
&self,
|
|
||||||
src_addr: Option<IpAddr>,
|
|
||||||
src: &'a [u8],
|
|
||||||
dst: &'b mut [u8],
|
|
||||||
) -> Result<Packet<'a>, TunnResult<'b>> {
|
|
||||||
let packet = Tunn::parse_incoming_packet(src)?;
|
|
||||||
|
|
||||||
// Verify and rate limit handshake messages only
|
|
||||||
if let Packet::HandshakeInit(HandshakeInit { sender_idx, .. })
|
|
||||||
| Packet::HandshakeResponse(HandshakeResponse { sender_idx, .. }) = packet
|
|
||||||
{
|
|
||||||
let (msg, macs) = src.split_at(src.len() - 32);
|
|
||||||
let (mac1, mac2) = macs.split_at(16);
|
|
||||||
|
|
||||||
let computed_mac1 = b2s_keyed_mac_16(&self.mac1_key, msg);
|
|
||||||
verify_slices_are_equal(&computed_mac1[..16], mac1)
|
|
||||||
.map_err(|_| TunnResult::Err(WireGuardError::InvalidMac))?;
|
|
||||||
|
|
||||||
if self.is_under_load() {
|
|
||||||
let addr = match src_addr {
|
|
||||||
None => return Err(TunnResult::Err(WireGuardError::UnderLoad)),
|
|
||||||
Some(addr) => addr,
|
|
||||||
};
|
|
||||||
|
|
||||||
// Only given an address can we validate mac2
|
|
||||||
let cookie = self.current_cookie(addr);
|
|
||||||
let computed_mac2 = b2s_keyed_mac_16_2(&cookie, msg, mac1);
|
|
||||||
|
|
||||||
if verify_slices_are_equal(&computed_mac2[..16], mac2).is_err() {
|
|
||||||
let cookie_packet = self
|
|
||||||
.format_cookie_reply(sender_idx, cookie, mac1, dst)
|
|
||||||
.map_err(TunnResult::Err)?;
|
|
||||||
return Err(TunnResult::WriteToNetwork(cookie_packet));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(packet)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,329 +0,0 @@
|
||||||
// Copyright (c) 2019 Cloudflare, Inc. All rights reserved.
|
|
||||||
// SPDX-License-Identifier: BSD-3-Clause
|
|
||||||
|
|
||||||
use super::PacketData;
|
|
||||||
use crate::noise::errors::WireGuardError;
|
|
||||||
use parking_lot::Mutex;
|
|
||||||
use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305};
|
|
||||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
|
||||||
|
|
||||||
pub struct Session {
|
|
||||||
pub(crate) receiving_index: u32,
|
|
||||||
sending_index: u32,
|
|
||||||
receiver: LessSafeKey,
|
|
||||||
sender: LessSafeKey,
|
|
||||||
sending_key_counter: AtomicUsize,
|
|
||||||
receiving_key_counter: Mutex<ReceivingKeyCounterValidator>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::fmt::Debug for Session {
|
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
|
||||||
write!(
|
|
||||||
f,
|
|
||||||
"Session: {}<- ->{}",
|
|
||||||
self.receiving_index, self.sending_index
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Where encrypted data resides in a data packet
|
|
||||||
const DATA_OFFSET: usize = 16;
|
|
||||||
/// The overhead of the AEAD
|
|
||||||
const AEAD_SIZE: usize = 16;
|
|
||||||
|
|
||||||
// Receiving buffer constants
|
|
||||||
const WORD_SIZE: u64 = 64;
|
|
||||||
const N_WORDS: u64 = 16; // Suffice to reorder 64*16 = 1024 packets; can be increased at will
|
|
||||||
const N_BITS: u64 = WORD_SIZE * N_WORDS;
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Default)]
|
|
||||||
struct ReceivingKeyCounterValidator {
|
|
||||||
/// In order to avoid replays while allowing for some reordering of the packets, we keep a
|
|
||||||
/// bitmap of received packets, and the value of the highest counter
|
|
||||||
next: u64,
|
|
||||||
/// Used to estimate packet loss
|
|
||||||
receive_cnt: u64,
|
|
||||||
bitmap: [u64; N_WORDS as usize],
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ReceivingKeyCounterValidator {
|
|
||||||
#[inline(always)]
|
|
||||||
fn set_bit(&mut self, idx: u64) {
|
|
||||||
let bit_idx = idx % N_BITS;
|
|
||||||
let word = (bit_idx / WORD_SIZE) as usize;
|
|
||||||
let bit = (bit_idx % WORD_SIZE) as usize;
|
|
||||||
self.bitmap[word] |= 1 << bit;
|
|
||||||
}
|
|
||||||
|
|
||||||
#[inline(always)]
|
|
||||||
fn clear_bit(&mut self, idx: u64) {
|
|
||||||
let bit_idx = idx % N_BITS;
|
|
||||||
let word = (bit_idx / WORD_SIZE) as usize;
|
|
||||||
let bit = (bit_idx % WORD_SIZE) as usize;
|
|
||||||
self.bitmap[word] &= !(1u64 << bit);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Clear the word that contains idx
|
|
||||||
#[inline(always)]
|
|
||||||
fn clear_word(&mut self, idx: u64) {
|
|
||||||
let bit_idx = idx % N_BITS;
|
|
||||||
let word = (bit_idx / WORD_SIZE) as usize;
|
|
||||||
self.bitmap[word] = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns true if bit is set, false otherwise
|
|
||||||
#[inline(always)]
|
|
||||||
fn check_bit(&self, idx: u64) -> bool {
|
|
||||||
let bit_idx = idx % N_BITS;
|
|
||||||
let word = (bit_idx / WORD_SIZE) as usize;
|
|
||||||
let bit = (bit_idx % WORD_SIZE) as usize;
|
|
||||||
((self.bitmap[word] >> bit) & 1) == 1
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns true if the counter was not yet received, and is not too far back
|
|
||||||
#[inline(always)]
|
|
||||||
fn will_accept(&self, counter: u64) -> Result<(), WireGuardError> {
|
|
||||||
if counter >= self.next {
|
|
||||||
// As long as the counter is growing no replay took place for sure
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
if counter + N_BITS < self.next {
|
|
||||||
// Drop if too far back
|
|
||||||
return Err(WireGuardError::InvalidCounter);
|
|
||||||
}
|
|
||||||
if !self.check_bit(counter) {
|
|
||||||
Ok(())
|
|
||||||
} else {
|
|
||||||
Err(WireGuardError::DuplicateCounter)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Marks the counter as received, and returns true if it is still good (in case during
|
|
||||||
/// decryption something changed)
|
|
||||||
#[inline(always)]
|
|
||||||
fn mark_did_receive(&mut self, counter: u64) -> Result<(), WireGuardError> {
|
|
||||||
if counter + N_BITS < self.next {
|
|
||||||
// Drop if too far back
|
|
||||||
return Err(WireGuardError::InvalidCounter);
|
|
||||||
}
|
|
||||||
if counter == self.next {
|
|
||||||
// Usually the packets arrive in order, in that case we simply mark the bit and
|
|
||||||
// increment the counter
|
|
||||||
self.set_bit(counter);
|
|
||||||
self.next += 1;
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
if counter < self.next {
|
|
||||||
// A packet arrived out of order, check if it is valid, and mark
|
|
||||||
if self.check_bit(counter) {
|
|
||||||
return Err(WireGuardError::InvalidCounter);
|
|
||||||
}
|
|
||||||
self.set_bit(counter);
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
// Packets where dropped, or maybe reordered, skip them and mark unused
|
|
||||||
if counter - self.next >= N_BITS {
|
|
||||||
// Too far ahead, clear all the bits
|
|
||||||
for c in self.bitmap.iter_mut() {
|
|
||||||
*c = 0;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
let mut i = self.next;
|
|
||||||
while i % WORD_SIZE != 0 && i < counter {
|
|
||||||
// Clear until i aligned to word size
|
|
||||||
self.clear_bit(i);
|
|
||||||
i += 1;
|
|
||||||
}
|
|
||||||
while i + WORD_SIZE < counter {
|
|
||||||
// Clear whole word at a time
|
|
||||||
self.clear_word(i);
|
|
||||||
i = (i + WORD_SIZE) & 0u64.wrapping_sub(WORD_SIZE);
|
|
||||||
}
|
|
||||||
while i < counter {
|
|
||||||
// Clear any remaining bits
|
|
||||||
self.clear_bit(i);
|
|
||||||
i += 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
self.set_bit(counter);
|
|
||||||
self.next = counter + 1;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Session {
|
|
||||||
pub(super) fn new(
|
|
||||||
local_index: u32,
|
|
||||||
peer_index: u32,
|
|
||||||
receiving_key: [u8; 32],
|
|
||||||
sending_key: [u8; 32],
|
|
||||||
) -> Session {
|
|
||||||
Session {
|
|
||||||
receiving_index: local_index,
|
|
||||||
sending_index: peer_index,
|
|
||||||
receiver: LessSafeKey::new(
|
|
||||||
UnboundKey::new(&CHACHA20_POLY1305, &receiving_key).unwrap(),
|
|
||||||
),
|
|
||||||
sender: LessSafeKey::new(UnboundKey::new(&CHACHA20_POLY1305, &sending_key).unwrap()),
|
|
||||||
sending_key_counter: AtomicUsize::new(0),
|
|
||||||
receiving_key_counter: Mutex::new(Default::default()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(super) fn local_index(&self) -> usize {
|
|
||||||
self.receiving_index as usize
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns true if receiving counter is good to use
|
|
||||||
fn receiving_counter_quick_check(&self, counter: u64) -> Result<(), WireGuardError> {
|
|
||||||
let counter_validator = self.receiving_key_counter.lock();
|
|
||||||
counter_validator.will_accept(counter)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns true if receiving counter is good to use, and marks it as used {
|
|
||||||
fn receiving_counter_mark(&self, counter: u64) -> Result<(), WireGuardError> {
|
|
||||||
let mut counter_validator = self.receiving_key_counter.lock();
|
|
||||||
let ret = counter_validator.mark_did_receive(counter);
|
|
||||||
if ret.is_ok() {
|
|
||||||
counter_validator.receive_cnt += 1;
|
|
||||||
}
|
|
||||||
ret
|
|
||||||
}
|
|
||||||
|
|
||||||
/// src - an IP packet from the interface
|
|
||||||
/// dst - pre-allocated space to hold the encapsulating UDP packet to send over the network
|
|
||||||
/// returns the size of the formatted packet
|
|
||||||
pub(super) fn format_packet_data<'a>(&self, src: &[u8], dst: &'a mut [u8]) -> &'a mut [u8] {
|
|
||||||
if dst.len() < src.len() + super::DATA_OVERHEAD_SZ {
|
|
||||||
panic!("The destination buffer is too small");
|
|
||||||
}
|
|
||||||
|
|
||||||
let sending_key_counter = self.sending_key_counter.fetch_add(1, Ordering::Relaxed) as u64;
|
|
||||||
|
|
||||||
let (message_type, rest) = dst.split_at_mut(4);
|
|
||||||
let (receiver_index, rest) = rest.split_at_mut(4);
|
|
||||||
let (counter, data) = rest.split_at_mut(8);
|
|
||||||
|
|
||||||
message_type.copy_from_slice(&super::DATA.to_le_bytes());
|
|
||||||
receiver_index.copy_from_slice(&self.sending_index.to_le_bytes());
|
|
||||||
counter.copy_from_slice(&sending_key_counter.to_le_bytes());
|
|
||||||
|
|
||||||
// TODO: spec requires padding to 16 bytes, but actually works fine without it
|
|
||||||
let n = {
|
|
||||||
let mut nonce = [0u8; 12];
|
|
||||||
nonce[4..12].copy_from_slice(&sending_key_counter.to_le_bytes());
|
|
||||||
data[..src.len()].copy_from_slice(src);
|
|
||||||
self.sender
|
|
||||||
.seal_in_place_separate_tag(
|
|
||||||
Nonce::assume_unique_for_key(nonce),
|
|
||||||
Aad::from(&[]),
|
|
||||||
&mut data[..src.len()],
|
|
||||||
)
|
|
||||||
.map(|tag| {
|
|
||||||
data[src.len()..src.len() + AEAD_SIZE].copy_from_slice(tag.as_ref());
|
|
||||||
src.len() + AEAD_SIZE
|
|
||||||
})
|
|
||||||
.unwrap()
|
|
||||||
};
|
|
||||||
|
|
||||||
&mut dst[..DATA_OFFSET + n]
|
|
||||||
}
|
|
||||||
|
|
||||||
/// packet - a data packet we received from the network
|
|
||||||
/// dst - pre-allocated space to hold the encapsulated IP packet, to send to the interface
|
|
||||||
/// dst will always take less space than src
|
|
||||||
/// return the size of the encapsulated packet on success
|
|
||||||
pub(super) fn receive_packet_data<'a>(
|
|
||||||
&self,
|
|
||||||
packet: PacketData,
|
|
||||||
dst: &'a mut [u8],
|
|
||||||
) -> Result<&'a mut [u8], WireGuardError> {
|
|
||||||
let ct_len = packet.encrypted_encapsulated_packet.len();
|
|
||||||
if dst.len() < ct_len {
|
|
||||||
// This is a very incorrect use of the library, therefore panic and not error
|
|
||||||
panic!("The destination buffer is too small");
|
|
||||||
}
|
|
||||||
if packet.receiver_idx != self.receiving_index {
|
|
||||||
return Err(WireGuardError::WrongIndex);
|
|
||||||
}
|
|
||||||
// Don't reuse counters, in case this is a replay attack we want to quickly check the counter without running expensive decryption
|
|
||||||
self.receiving_counter_quick_check(packet.counter)?;
|
|
||||||
|
|
||||||
let ret = {
|
|
||||||
let mut nonce = [0u8; 12];
|
|
||||||
nonce[4..12].copy_from_slice(&packet.counter.to_le_bytes());
|
|
||||||
dst[..ct_len].copy_from_slice(packet.encrypted_encapsulated_packet);
|
|
||||||
self.receiver
|
|
||||||
.open_in_place(
|
|
||||||
Nonce::assume_unique_for_key(nonce),
|
|
||||||
Aad::from(&[]),
|
|
||||||
&mut dst[..ct_len],
|
|
||||||
)
|
|
||||||
.map_err(|_| WireGuardError::InvalidAeadTag)?
|
|
||||||
};
|
|
||||||
|
|
||||||
// After decryption is done, check counter again, and mark as received
|
|
||||||
self.receiving_counter_mark(packet.counter)?;
|
|
||||||
Ok(ret)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns the estimated downstream packet loss for this session
|
|
||||||
pub(super) fn current_packet_cnt(&self) -> (u64, u64) {
|
|
||||||
let counter_validator = self.receiving_key_counter.lock();
|
|
||||||
(counter_validator.next, counter_validator.receive_cnt)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
#[test]
|
|
||||||
fn test_replay_counter() {
|
|
||||||
let mut c: ReceivingKeyCounterValidator = Default::default();
|
|
||||||
|
|
||||||
assert!(c.mark_did_receive(0).is_ok());
|
|
||||||
assert!(c.mark_did_receive(0).is_err());
|
|
||||||
assert!(c.mark_did_receive(1).is_ok());
|
|
||||||
assert!(c.mark_did_receive(1).is_err());
|
|
||||||
assert!(c.mark_did_receive(63).is_ok());
|
|
||||||
assert!(c.mark_did_receive(63).is_err());
|
|
||||||
assert!(c.mark_did_receive(15).is_ok());
|
|
||||||
assert!(c.mark_did_receive(15).is_err());
|
|
||||||
|
|
||||||
for i in 64..N_BITS + 128 {
|
|
||||||
assert!(c.mark_did_receive(i).is_ok());
|
|
||||||
assert!(c.mark_did_receive(i).is_err());
|
|
||||||
}
|
|
||||||
|
|
||||||
assert!(c.mark_did_receive(N_BITS * 3).is_ok());
|
|
||||||
for i in 0..=N_BITS * 2 {
|
|
||||||
assert!(matches!(
|
|
||||||
c.will_accept(i),
|
|
||||||
Err(WireGuardError::InvalidCounter)
|
|
||||||
));
|
|
||||||
assert!(c.mark_did_receive(i).is_err());
|
|
||||||
}
|
|
||||||
for i in N_BITS * 2 + 1..N_BITS * 3 {
|
|
||||||
assert!(c.will_accept(i).is_ok());
|
|
||||||
}
|
|
||||||
assert!(matches!(
|
|
||||||
c.will_accept(N_BITS * 3),
|
|
||||||
Err(WireGuardError::DuplicateCounter)
|
|
||||||
));
|
|
||||||
|
|
||||||
for i in (N_BITS * 2 + 1..N_BITS * 3).rev() {
|
|
||||||
assert!(c.mark_did_receive(i).is_ok());
|
|
||||||
assert!(c.mark_did_receive(i).is_err());
|
|
||||||
}
|
|
||||||
|
|
||||||
assert!(c.mark_did_receive(N_BITS * 3 + 70).is_ok());
|
|
||||||
assert!(c.mark_did_receive(N_BITS * 3 + 71).is_ok());
|
|
||||||
assert!(c.mark_did_receive(N_BITS * 3 + 72).is_ok());
|
|
||||||
assert!(c.mark_did_receive(N_BITS * 3 + 72 + 125).is_ok());
|
|
||||||
assert!(c.mark_did_receive(N_BITS * 3 + 63).is_ok());
|
|
||||||
|
|
||||||
assert!(c.mark_did_receive(N_BITS * 3 + 70).is_err());
|
|
||||||
assert!(c.mark_did_receive(N_BITS * 3 + 71).is_err());
|
|
||||||
assert!(c.mark_did_receive(N_BITS * 3 + 72).is_err());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,335 +0,0 @@
|
||||||
// Copyright (c) 2019 Cloudflare, Inc. All rights reserved.
|
|
||||||
// SPDX-License-Identifier: BSD-3-Clause
|
|
||||||
|
|
||||||
use super::errors::WireGuardError;
|
|
||||||
use crate::noise::{Tunn, TunnResult};
|
|
||||||
use std::mem;
|
|
||||||
use std::ops::{Index, IndexMut};
|
|
||||||
|
|
||||||
use std::time::Duration;
|
|
||||||
|
|
||||||
#[cfg(feature = "mock-instant")]
|
|
||||||
use mock_instant::Instant;
|
|
||||||
|
|
||||||
#[cfg(not(feature = "mock-instant"))]
|
|
||||||
use crate::sleepyinstant::Instant;
|
|
||||||
|
|
||||||
// Some constants, represent time in seconds
|
|
||||||
// https://www.wireguard.com/papers/wireguard.pdf#page=14
|
|
||||||
pub(crate) const REKEY_AFTER_TIME: Duration = Duration::from_secs(120);
|
|
||||||
const REJECT_AFTER_TIME: Duration = Duration::from_secs(180);
|
|
||||||
const REKEY_ATTEMPT_TIME: Duration = Duration::from_secs(90);
|
|
||||||
pub(crate) const REKEY_TIMEOUT: Duration = Duration::from_secs(5);
|
|
||||||
const KEEPALIVE_TIMEOUT: Duration = Duration::from_secs(10);
|
|
||||||
const COOKIE_EXPIRATION_TIME: Duration = Duration::from_secs(120);
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub enum TimerName {
|
|
||||||
/// Current time, updated each call to `update_timers`
|
|
||||||
TimeCurrent,
|
|
||||||
/// Time when last handshake was completed
|
|
||||||
TimeSessionEstablished,
|
|
||||||
/// Time the last attempt for a new handshake began
|
|
||||||
TimeLastHandshakeStarted,
|
|
||||||
/// Time we last received and authenticated a packet
|
|
||||||
TimeLastPacketReceived,
|
|
||||||
/// Time we last send a packet
|
|
||||||
TimeLastPacketSent,
|
|
||||||
/// Time we last received and authenticated a DATA packet
|
|
||||||
TimeLastDataPacketReceived,
|
|
||||||
/// Time we last send a DATA packet
|
|
||||||
TimeLastDataPacketSent,
|
|
||||||
/// Time we last received a cookie
|
|
||||||
TimeCookieReceived,
|
|
||||||
/// Time we last sent persistent keepalive
|
|
||||||
TimePersistentKeepalive,
|
|
||||||
Top,
|
|
||||||
}
|
|
||||||
|
|
||||||
use self::TimerName::*;
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct Timers {
|
|
||||||
/// Is the owner of the timer the initiator or the responder for the last handshake?
|
|
||||||
is_initiator: bool,
|
|
||||||
/// Start time of the tunnel
|
|
||||||
time_started: Instant,
|
|
||||||
timers: [Duration; TimerName::Top as usize],
|
|
||||||
pub(super) session_timers: [Duration; super::N_SESSIONS],
|
|
||||||
/// Did we receive data without sending anything back?
|
|
||||||
want_keepalive: bool,
|
|
||||||
/// Did we send data without hearing back?
|
|
||||||
want_handshake: bool,
|
|
||||||
persistent_keepalive: usize,
|
|
||||||
/// Should this timer call reset rr function (if not a shared rr instance)
|
|
||||||
pub(super) should_reset_rr: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Timers {
|
|
||||||
pub(super) fn new(persistent_keepalive: Option<u16>, reset_rr: bool) -> Timers {
|
|
||||||
Timers {
|
|
||||||
is_initiator: false,
|
|
||||||
time_started: Instant::now(),
|
|
||||||
timers: Default::default(),
|
|
||||||
session_timers: Default::default(),
|
|
||||||
want_keepalive: Default::default(),
|
|
||||||
want_handshake: Default::default(),
|
|
||||||
persistent_keepalive: usize::from(persistent_keepalive.unwrap_or(0)),
|
|
||||||
should_reset_rr: reset_rr,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn is_initiator(&self) -> bool {
|
|
||||||
self.is_initiator
|
|
||||||
}
|
|
||||||
|
|
||||||
// We don't really clear the timers, but we set them to the current time to
|
|
||||||
// so the reference time frame is the same
|
|
||||||
pub(super) fn clear(&mut self) {
|
|
||||||
let now = Instant::now().duration_since(self.time_started);
|
|
||||||
for t in &mut self.timers[..] {
|
|
||||||
*t = now;
|
|
||||||
}
|
|
||||||
self.want_handshake = false;
|
|
||||||
self.want_keepalive = false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Index<TimerName> for Timers {
|
|
||||||
type Output = Duration;
|
|
||||||
fn index(&self, index: TimerName) -> &Duration {
|
|
||||||
&self.timers[index as usize]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl IndexMut<TimerName> for Timers {
|
|
||||||
fn index_mut(&mut self, index: TimerName) -> &mut Duration {
|
|
||||||
&mut self.timers[index as usize]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Tunn {
|
|
||||||
pub(super) fn timer_tick(&mut self, timer_name: TimerName) {
|
|
||||||
match timer_name {
|
|
||||||
TimeLastPacketReceived => {
|
|
||||||
self.timers.want_keepalive = true;
|
|
||||||
self.timers.want_handshake = false;
|
|
||||||
}
|
|
||||||
TimeLastPacketSent => {
|
|
||||||
self.timers.want_handshake = true;
|
|
||||||
self.timers.want_keepalive = false;
|
|
||||||
}
|
|
||||||
_ => {}
|
|
||||||
}
|
|
||||||
|
|
||||||
let time = self.timers[TimeCurrent];
|
|
||||||
self.timers[timer_name] = time;
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(super) fn timer_tick_session_established(
|
|
||||||
&mut self,
|
|
||||||
is_initiator: bool,
|
|
||||||
session_idx: usize,
|
|
||||||
) {
|
|
||||||
self.timer_tick(TimeSessionEstablished);
|
|
||||||
self.timers.session_timers[session_idx % crate::noise::N_SESSIONS] =
|
|
||||||
self.timers[TimeCurrent];
|
|
||||||
self.timers.is_initiator = is_initiator;
|
|
||||||
}
|
|
||||||
|
|
||||||
// We don't really clear the timers, but we set them to the current time to
|
|
||||||
// so the reference time frame is the same
|
|
||||||
fn clear_all(&mut self) {
|
|
||||||
for session in &mut self.sessions {
|
|
||||||
*session = None;
|
|
||||||
}
|
|
||||||
|
|
||||||
self.packet_queue.clear();
|
|
||||||
|
|
||||||
self.timers.clear();
|
|
||||||
}
|
|
||||||
|
|
||||||
fn update_session_timers(&mut self, time_now: Duration) {
|
|
||||||
let timers = &mut self.timers;
|
|
||||||
|
|
||||||
for (i, t) in timers.session_timers.iter_mut().enumerate() {
|
|
||||||
if time_now - *t > REJECT_AFTER_TIME {
|
|
||||||
if let Some(session) = self.sessions[i].take() {
|
|
||||||
tracing::debug!(
|
|
||||||
message = "SESSION_EXPIRED(REJECT_AFTER_TIME)",
|
|
||||||
session = session.receiving_index
|
|
||||||
);
|
|
||||||
}
|
|
||||||
*t = time_now;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn update_timers<'a>(&mut self, dst: &'a mut [u8]) -> TunnResult<'a> {
|
|
||||||
let mut handshake_initiation_required = false;
|
|
||||||
let mut keepalive_required = false;
|
|
||||||
|
|
||||||
let time = Instant::now();
|
|
||||||
|
|
||||||
if self.timers.should_reset_rr {
|
|
||||||
self.rate_limiter.reset_count();
|
|
||||||
}
|
|
||||||
|
|
||||||
// All the times are counted from tunnel initiation, for efficiency our timers are rounded
|
|
||||||
// to a second, as there is no real benefit to having highly accurate timers.
|
|
||||||
let now = time.duration_since(self.timers.time_started);
|
|
||||||
self.timers[TimeCurrent] = now;
|
|
||||||
|
|
||||||
self.update_session_timers(now);
|
|
||||||
|
|
||||||
// Load timers only once:
|
|
||||||
let session_established = self.timers[TimeSessionEstablished];
|
|
||||||
let handshake_started = self.timers[TimeLastHandshakeStarted];
|
|
||||||
let aut_packet_received = self.timers[TimeLastPacketReceived];
|
|
||||||
let aut_packet_sent = self.timers[TimeLastPacketSent];
|
|
||||||
let data_packet_received = self.timers[TimeLastDataPacketReceived];
|
|
||||||
let data_packet_sent = self.timers[TimeLastDataPacketSent];
|
|
||||||
let persistent_keepalive = self.timers.persistent_keepalive;
|
|
||||||
|
|
||||||
{
|
|
||||||
if self.handshake.is_expired() {
|
|
||||||
return TunnResult::Err(WireGuardError::ConnectionExpired);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clear cookie after COOKIE_EXPIRATION_TIME
|
|
||||||
if self.handshake.has_cookie()
|
|
||||||
&& now - self.timers[TimeCookieReceived] >= COOKIE_EXPIRATION_TIME
|
|
||||||
{
|
|
||||||
self.handshake.clear_cookie();
|
|
||||||
}
|
|
||||||
|
|
||||||
// All ephemeral private keys and symmetric session keys are zeroed out after
|
|
||||||
// (REJECT_AFTER_TIME * 3) ms if no new keys have been exchanged.
|
|
||||||
if now - session_established >= REJECT_AFTER_TIME * 3 {
|
|
||||||
tracing::error!("CONNECTION_EXPIRED(REJECT_AFTER_TIME * 3)");
|
|
||||||
self.handshake.set_expired();
|
|
||||||
self.clear_all();
|
|
||||||
return TunnResult::Err(WireGuardError::ConnectionExpired);
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(time_init_sent) = self.handshake.timer() {
|
|
||||||
// Handshake Initiation Retransmission
|
|
||||||
if now - handshake_started >= REKEY_ATTEMPT_TIME {
|
|
||||||
// After REKEY_ATTEMPT_TIME ms of trying to initiate a new handshake,
|
|
||||||
// the retries give up and cease, and clear all existing packets queued
|
|
||||||
// up to be sent. If a packet is explicitly queued up to be sent, then
|
|
||||||
// this timer is reset.
|
|
||||||
tracing::error!("CONNECTION_EXPIRED(REKEY_ATTEMPT_TIME)");
|
|
||||||
self.handshake.set_expired();
|
|
||||||
self.clear_all();
|
|
||||||
return TunnResult::Err(WireGuardError::ConnectionExpired);
|
|
||||||
}
|
|
||||||
|
|
||||||
if time_init_sent.elapsed() >= REKEY_TIMEOUT {
|
|
||||||
// We avoid using `time` here, because it can be earlier than `time_init_sent`.
|
|
||||||
// Once `checked_duration_since` is stable we can use that.
|
|
||||||
// A handshake initiation is retried after REKEY_TIMEOUT + jitter ms,
|
|
||||||
// if a response has not been received, where jitter is some random
|
|
||||||
// value between 0 and 333 ms.
|
|
||||||
tracing::warn!("HANDSHAKE(REKEY_TIMEOUT)");
|
|
||||||
handshake_initiation_required = true;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if self.timers.is_initiator() {
|
|
||||||
// After sending a packet, if the sender was the original initiator
|
|
||||||
// of the handshake and if the current session key is REKEY_AFTER_TIME
|
|
||||||
// ms old, we initiate a new handshake. If the sender was the original
|
|
||||||
// responder of the handshake, it does not re-initiate a new handshake
|
|
||||||
// after REKEY_AFTER_TIME ms like the original initiator does.
|
|
||||||
if session_established < data_packet_sent
|
|
||||||
&& now - session_established >= REKEY_AFTER_TIME
|
|
||||||
{
|
|
||||||
tracing::debug!("HANDSHAKE(REKEY_AFTER_TIME (on send))");
|
|
||||||
handshake_initiation_required = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// After receiving a packet, if the receiver was the original initiator
|
|
||||||
// of the handshake and if the current session key is REJECT_AFTER_TIME
|
|
||||||
// - KEEPALIVE_TIMEOUT - REKEY_TIMEOUT ms old, we initiate a new
|
|
||||||
// handshake.
|
|
||||||
if session_established < data_packet_received
|
|
||||||
&& now - session_established
|
|
||||||
>= REJECT_AFTER_TIME - KEEPALIVE_TIMEOUT - REKEY_TIMEOUT
|
|
||||||
{
|
|
||||||
tracing::warn!(
|
|
||||||
"HANDSHAKE(REJECT_AFTER_TIME - KEEPALIVE_TIMEOUT - \
|
|
||||||
REKEY_TIMEOUT \
|
|
||||||
(on receive))"
|
|
||||||
);
|
|
||||||
handshake_initiation_required = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we have sent a packet to a given peer but have not received a
|
|
||||||
// packet after from that peer for (KEEPALIVE + REKEY_TIMEOUT) ms,
|
|
||||||
// we initiate a new handshake.
|
|
||||||
if data_packet_sent > aut_packet_received
|
|
||||||
&& now - aut_packet_received >= KEEPALIVE_TIMEOUT + REKEY_TIMEOUT
|
|
||||||
&& mem::replace(&mut self.timers.want_handshake, false)
|
|
||||||
{
|
|
||||||
tracing::warn!("HANDSHAKE(KEEPALIVE + REKEY_TIMEOUT)");
|
|
||||||
handshake_initiation_required = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
if !handshake_initiation_required {
|
|
||||||
// If a packet has been received from a given peer, but we have not sent one back
|
|
||||||
// to the given peer in KEEPALIVE ms, we send an empty packet.
|
|
||||||
if data_packet_received > aut_packet_sent
|
|
||||||
&& now - aut_packet_sent >= KEEPALIVE_TIMEOUT
|
|
||||||
&& mem::replace(&mut self.timers.want_keepalive, false)
|
|
||||||
{
|
|
||||||
tracing::debug!("KEEPALIVE(KEEPALIVE_TIMEOUT)");
|
|
||||||
keepalive_required = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Persistent KEEPALIVE
|
|
||||||
if persistent_keepalive > 0
|
|
||||||
&& (now - self.timers[TimePersistentKeepalive]
|
|
||||||
>= Duration::from_secs(persistent_keepalive as _))
|
|
||||||
{
|
|
||||||
tracing::debug!("KEEPALIVE(PERSISTENT_KEEPALIVE)");
|
|
||||||
self.timer_tick(TimePersistentKeepalive);
|
|
||||||
keepalive_required = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if handshake_initiation_required {
|
|
||||||
return self.format_handshake_initiation(dst, true);
|
|
||||||
}
|
|
||||||
|
|
||||||
if keepalive_required {
|
|
||||||
return self.encapsulate(&[], dst);
|
|
||||||
}
|
|
||||||
|
|
||||||
TunnResult::Done
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn time_since_last_handshake(&self) -> Option<Duration> {
|
|
||||||
let current_session = self.current;
|
|
||||||
if self.sessions[current_session % super::N_SESSIONS].is_some() {
|
|
||||||
let duration_since_tun_start = Instant::now().duration_since(self.timers.time_started);
|
|
||||||
let duration_since_session_established = self.timers[TimeSessionEstablished];
|
|
||||||
|
|
||||||
Some(duration_since_tun_start - duration_since_session_established)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn persistent_keepalive(&self) -> Option<u16> {
|
|
||||||
let keepalive = self.timers.persistent_keepalive;
|
|
||||||
|
|
||||||
if keepalive > 0 {
|
|
||||||
Some(keepalive as u16)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,33 +0,0 @@
|
||||||
pub(crate) struct KeyBytes(pub [u8; 32]);
|
|
||||||
|
|
||||||
impl std::str::FromStr for KeyBytes {
|
|
||||||
type Err = &'static str;
|
|
||||||
|
|
||||||
/// Can parse a secret key from a hex or base64 encoded string.
|
|
||||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
|
||||||
let mut internal = [0u8; 32];
|
|
||||||
|
|
||||||
match s.len() {
|
|
||||||
64 => {
|
|
||||||
// Try to parse as hex
|
|
||||||
for i in 0..32 {
|
|
||||||
internal[i] = u8::from_str_radix(&s[i * 2..=i * 2 + 1], 16)
|
|
||||||
.map_err(|_| "Illegal character in key")?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
43 | 44 => {
|
|
||||||
// Try to parse as base64
|
|
||||||
if let Ok(decoded_key) = base64::decode(s) {
|
|
||||||
if decoded_key.len() == internal.len() {
|
|
||||||
internal[..].copy_from_slice(&decoded_key);
|
|
||||||
} else {
|
|
||||||
return Err("Illegal character in key");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ => return Err("Illegal key size"),
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(KeyBytes(internal))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,13 +0,0 @@
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use tun::TunOptions;
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub enum DaemonCommand {
|
|
||||||
Start(DaemonStartOptions),
|
|
||||||
Stop,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
|
||||||
pub struct DaemonStartOptions {
|
|
||||||
pub(super) tun: TunOptions,
|
|
||||||
}
|
|
||||||
|
|
@ -1,40 +0,0 @@
|
||||||
use super::*;
|
|
||||||
|
|
||||||
pub struct DaemonInstance {
|
|
||||||
rx: mpsc::Receiver<DaemonCommand>,
|
|
||||||
tun_interface: Option<TunInterface>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl DaemonInstance {
|
|
||||||
pub fn new(rx: mpsc::Receiver<DaemonCommand>) -> Self {
|
|
||||||
Self {
|
|
||||||
rx,
|
|
||||||
tun_interface: None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn run(&mut self) -> Result<()> {
|
|
||||||
while let Some(command) = self.rx.recv().await {
|
|
||||||
match command {
|
|
||||||
DaemonCommand::Start(options) => {
|
|
||||||
if self.tun_interface.is_none() {
|
|
||||||
self.tun_interface = Some(options.tun.open()?);
|
|
||||||
eprintln!("Daemon starting tun interface.");
|
|
||||||
} else {
|
|
||||||
eprintln!("Got start, but tun interface already up.");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
DaemonCommand::Stop => {
|
|
||||||
if self.tun_interface.is_some() {
|
|
||||||
self.tun_interface = None;
|
|
||||||
eprintln!("Daemon stopping tun interface.");
|
|
||||||
} else {
|
|
||||||
eprintln!("Got stop, but tun interface is not up.")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,19 +0,0 @@
|
||||||
use super::*;
|
|
||||||
use tokio::sync::mpsc;
|
|
||||||
|
|
||||||
mod command;
|
|
||||||
mod instance;
|
|
||||||
mod net;
|
|
||||||
|
|
||||||
use instance::DaemonInstance;
|
|
||||||
use net::listen;
|
|
||||||
|
|
||||||
pub use command::{DaemonCommand, DaemonStartOptions};
|
|
||||||
pub use net::DaemonClient;
|
|
||||||
|
|
||||||
pub async fn daemon_main() -> Result<()> {
|
|
||||||
let (tx, rx) = mpsc::channel(2);
|
|
||||||
let mut inst = DaemonInstance::new(rx);
|
|
||||||
|
|
||||||
tokio::try_join!(inst.run(), listen(tx)).map(|_| ())
|
|
||||||
}
|
|
||||||
|
|
@ -1,29 +0,0 @@
|
||||||
use super::*;
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
|
|
||||||
#[cfg(target_family = "unix")]
|
|
||||||
mod unix;
|
|
||||||
#[cfg(all(target_family = "unix", not(target_os = "linux")))]
|
|
||||||
pub use unix::{listen, DaemonClient};
|
|
||||||
|
|
||||||
#[cfg(target_os = "linux")]
|
|
||||||
mod systemd;
|
|
||||||
#[cfg(target_os = "linux")]
|
|
||||||
pub use systemd::{listen, DaemonClient};
|
|
||||||
|
|
||||||
#[cfg(target_os = "windows")]
|
|
||||||
mod windows;
|
|
||||||
#[cfg(target_os = "windows")]
|
|
||||||
pub use windows::{listen, DaemonClient};
|
|
||||||
|
|
||||||
#[derive(Clone, Serialize, Deserialize)]
|
|
||||||
pub struct DaemonRequest {
|
|
||||||
pub id: u32,
|
|
||||||
pub command: DaemonCommand,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Serialize, Deserialize)]
|
|
||||||
pub struct DaemonResponse {
|
|
||||||
// Error types can't be serialized, so this is the second best option.
|
|
||||||
result: std::result::Result<(), String>,
|
|
||||||
}
|
|
||||||
|
|
@ -1,16 +0,0 @@
|
||||||
use super::*;
|
|
||||||
use std::os::fd::IntoRawFd;
|
|
||||||
|
|
||||||
pub async fn listen(cmd_tx: mpsc::Sender<DaemonCommand>) -> Result<()> {
|
|
||||||
if !libsystemd::daemon::booted() || listen_with_systemd(cmd_tx.clone()).await.is_err() {
|
|
||||||
unix::listen(cmd_tx).await?;
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn listen_with_systemd(cmd_tx: mpsc::Sender<DaemonCommand>) -> Result<()> {
|
|
||||||
let fds = libsystemd::activation::receive_descriptors(false).unwrap();
|
|
||||||
super::unix::listen_with_optional_fd(cmd_tx, Some(fds[0].clone().into_raw_fd())).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub type DaemonClient = unix::DaemonClient;
|
|
||||||
|
|
@ -1,102 +0,0 @@
|
||||||
use super::*;
|
|
||||||
use std::{
|
|
||||||
os::fd::{FromRawFd, RawFd},
|
|
||||||
os::unix::net::UnixListener as StdUnixListener,
|
|
||||||
path::Path,
|
|
||||||
};
|
|
||||||
use tokio::{
|
|
||||||
io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
|
|
||||||
net::{UnixListener, UnixStream},
|
|
||||||
};
|
|
||||||
|
|
||||||
const UNIX_SOCKET_PATH: &str = "/run/burrow.sock";
|
|
||||||
|
|
||||||
pub async fn listen(cmd_tx: mpsc::Sender<DaemonCommand>) -> Result<()> {
|
|
||||||
listen_with_optional_fd(cmd_tx, None).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) async fn listen_with_optional_fd(
|
|
||||||
cmd_tx: mpsc::Sender<DaemonCommand>,
|
|
||||||
raw_fd: Option<RawFd>,
|
|
||||||
) -> Result<()> {
|
|
||||||
let path = Path::new(UNIX_SOCKET_PATH);
|
|
||||||
|
|
||||||
let listener = if let Some(raw_fd) = raw_fd {
|
|
||||||
let listener = unsafe { StdUnixListener::from_raw_fd(raw_fd) };
|
|
||||||
listener.set_nonblocking(true)?;
|
|
||||||
UnixListener::from_std(listener)
|
|
||||||
} else {
|
|
||||||
UnixListener::bind(path)
|
|
||||||
};
|
|
||||||
let listener = if let Ok(listener) = listener {
|
|
||||||
listener
|
|
||||||
} else {
|
|
||||||
// Won't help all that much, if we use the async version of fs.
|
|
||||||
std::fs::remove_file(path)?;
|
|
||||||
UnixListener::bind(path)?
|
|
||||||
};
|
|
||||||
loop {
|
|
||||||
let (stream, _) = listener.accept().await?;
|
|
||||||
let cmd_tx = cmd_tx.clone();
|
|
||||||
|
|
||||||
// I'm pretty sure we won't need to manually join / shut this down,
|
|
||||||
// `lines` will return Err during dropping, and this task should exit gracefully.
|
|
||||||
tokio::task::spawn(async {
|
|
||||||
let cmd_tx = cmd_tx;
|
|
||||||
let mut stream = stream;
|
|
||||||
let (mut read_stream, mut write_stream) = stream.split();
|
|
||||||
let buf_reader = BufReader::new(&mut read_stream);
|
|
||||||
let mut lines = buf_reader.lines();
|
|
||||||
while let Ok(Some(line)) = lines.next_line().await {
|
|
||||||
let mut res = DaemonResponse { result: Ok(()) };
|
|
||||||
let command = match serde_json::from_str::<DaemonRequest>(&line) {
|
|
||||||
Ok(req) => Some(req.command),
|
|
||||||
Err(e) => {
|
|
||||||
res.result = Err(format!("{}", e));
|
|
||||||
None
|
|
||||||
}
|
|
||||||
};
|
|
||||||
let mut res = serde_json::to_string(&res).unwrap();
|
|
||||||
res.push('\n');
|
|
||||||
|
|
||||||
write_stream.write_all(res.as_bytes()).await.unwrap();
|
|
||||||
|
|
||||||
// I want this to come at the very end so that we always send a reponse back.
|
|
||||||
if let Some(command) = command {
|
|
||||||
cmd_tx.send(command).await.unwrap();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct DaemonClient {
|
|
||||||
connection: UnixStream,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl DaemonClient {
|
|
||||||
pub async fn new() -> Result<Self> {
|
|
||||||
Self::new_with_path(UNIX_SOCKET_PATH).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn new_with_path(path: &str) -> Result<Self> {
|
|
||||||
let path = Path::new(path);
|
|
||||||
let connection = UnixStream::connect(path).await?;
|
|
||||||
|
|
||||||
Ok(Self { connection })
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn send_command(&mut self, command: DaemonCommand) -> Result<()> {
|
|
||||||
let mut command = serde_json::to_string(&DaemonRequest { id: 0, command })?;
|
|
||||||
command.push('\n');
|
|
||||||
|
|
||||||
self.connection.write_all(command.as_bytes()).await?;
|
|
||||||
let buf_reader = BufReader::new(&mut self.connection);
|
|
||||||
let mut lines = buf_reader.lines();
|
|
||||||
// This unwrap *should* never cause issues.
|
|
||||||
let response = lines.next_line().await?.unwrap();
|
|
||||||
let res: DaemonResponse = serde_json::from_str(&response)?;
|
|
||||||
res.result.unwrap();
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,17 +0,0 @@
|
||||||
use super::*;
|
|
||||||
|
|
||||||
pub async fn listen(_: mpsc::Sender<DaemonCommand>) -> Result<()> {
|
|
||||||
unimplemented!("This platform does not currently support daemon mode.")
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct DaemonClient;
|
|
||||||
|
|
||||||
impl DaemonClient {
|
|
||||||
pub async fn new() -> Result<Self> {
|
|
||||||
unimplemented!("This platform does not currently support daemon mode.")
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn send_command(&mut self, _: DaemonCommand) -> Result<()> {
|
|
||||||
unimplemented!("This platform does not currently support daemon mode.")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,8 +1,5 @@
|
||||||
use tracing::instrument;
|
|
||||||
|
|
||||||
// Check capabilities on Linux
|
// Check capabilities on Linux
|
||||||
#[cfg(target_os = "linux")]
|
#[cfg(target_os = "linux")]
|
||||||
#[instrument]
|
|
||||||
pub fn ensure_root() {
|
pub fn ensure_root() {
|
||||||
use caps::{has_cap, CapSet, Capability};
|
use caps::{has_cap, CapSet, Capability};
|
||||||
|
|
||||||
|
|
@ -22,7 +19,6 @@ pub fn ensure_root() {
|
||||||
|
|
||||||
// Check for root user on macOS
|
// Check for root user on macOS
|
||||||
#[cfg(target_vendor = "apple")]
|
#[cfg(target_vendor = "apple")]
|
||||||
#[instrument]
|
|
||||||
pub fn ensure_root() {
|
pub fn ensure_root() {
|
||||||
use nix::unistd::Uid;
|
use nix::unistd::Uid;
|
||||||
|
|
||||||
|
|
@ -34,7 +30,6 @@ pub fn ensure_root() {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(target_family = "windows")]
|
#[cfg(target_family = "windows")]
|
||||||
#[instrument]
|
|
||||||
pub fn ensure_root() {
|
pub fn ensure_root() {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,33 +1 @@
|
||||||
#![deny(missing_debug_implementations)]
|
|
||||||
pub mod ensureroot;
|
pub mod ensureroot;
|
||||||
|
|
||||||
#[cfg(any(target_os = "linux", target_vendor = "apple"))]
|
|
||||||
use std::{
|
|
||||||
mem,
|
|
||||||
os::fd::{AsRawFd, FromRawFd},
|
|
||||||
};
|
|
||||||
|
|
||||||
use tun::TunInterface;
|
|
||||||
|
|
||||||
// TODO Separate start and retrieve functions
|
|
||||||
|
|
||||||
#[cfg(any(target_os = "linux", target_vendor = "apple"))]
|
|
||||||
#[no_mangle]
|
|
||||||
pub extern "C" fn retrieve() -> i32 {
|
|
||||||
let iface2 = (1..100)
|
|
||||||
.filter_map(|i| {
|
|
||||||
let iface = unsafe { TunInterface::from_raw_fd(i) };
|
|
||||||
match iface.name() {
|
|
||||||
Ok(_name) => Some(iface),
|
|
||||||
Err(_) => {
|
|
||||||
mem::forget(iface);
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.next();
|
|
||||||
match iface2 {
|
|
||||||
Some(iface) => iface.as_raw_fd(),
|
|
||||||
None => -1,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -1,23 +1,7 @@
|
||||||
use anyhow::Context;
|
|
||||||
use std::mem;
|
|
||||||
#[cfg(any(target_os = "linux", target_vendor = "apple"))]
|
|
||||||
use std::os::fd::FromRawFd;
|
|
||||||
|
|
||||||
use clap::{Args, Parser, Subcommand};
|
use clap::{Args, Parser, Subcommand};
|
||||||
use tracing::instrument;
|
|
||||||
|
|
||||||
use tracing_log::LogTracer;
|
|
||||||
use tracing_oslog::OsLogger;
|
|
||||||
use tracing_subscriber::{prelude::*, FmtSubscriber};
|
|
||||||
use tokio::io::Result;
|
use tokio::io::Result;
|
||||||
#[cfg(any(target_os = "linux", target_vendor = "apple"))]
|
|
||||||
use burrow::retrieve;
|
|
||||||
use tun::TunInterface;
|
use tun::TunInterface;
|
||||||
|
|
||||||
mod daemon;
|
|
||||||
|
|
||||||
use daemon::{DaemonClient, DaemonCommand, DaemonStartOptions};
|
|
||||||
|
|
||||||
#[derive(Parser)]
|
#[derive(Parser)]
|
||||||
#[command(name = "Burrow")]
|
#[command(name = "Burrow")]
|
||||||
#[command(author = "Hack Club <team@hackclub.com>")]
|
#[command(author = "Hack Club <team@hackclub.com>")]
|
||||||
|
|
@ -38,108 +22,28 @@ struct Cli {
|
||||||
enum Commands {
|
enum Commands {
|
||||||
/// Start Burrow
|
/// Start Burrow
|
||||||
Start(StartArgs),
|
Start(StartArgs),
|
||||||
/// Retrieve the file descriptor of the tun interface
|
|
||||||
Retrieve(RetrieveArgs),
|
|
||||||
/// Stop Burrow daemon
|
|
||||||
Stop,
|
|
||||||
/// Start Burrow daemon
|
|
||||||
Daemon(DaemonArgs),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Args)]
|
#[derive(Args)]
|
||||||
struct StartArgs {}
|
struct StartArgs {}
|
||||||
|
|
||||||
#[derive(Args)]
|
async fn try_main() -> Result<()> {
|
||||||
struct RetrieveArgs {}
|
|
||||||
|
|
||||||
#[derive(Args)]
|
|
||||||
struct DaemonArgs {}
|
|
||||||
|
|
||||||
#[cfg(any(target_os = "linux", target_vendor = "apple"))]
|
|
||||||
async fn try_start() -> Result<()> {
|
|
||||||
let mut client = DaemonClient::new().await?;
|
|
||||||
client
|
|
||||||
.send_command(DaemonCommand::Start(DaemonStartOptions::default()))
|
|
||||||
.await
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(any(target_os = "linux", target_vendor = "apple"))]
|
|
||||||
#[instrument]
|
|
||||||
async fn try_retrieve() -> Result<()> {
|
|
||||||
LogTracer::init().context("Failed to initialize LogTracer").unwrap();
|
|
||||||
|
|
||||||
if cfg!(target_os = "linux") || cfg!(target_vendor = "apple") {
|
|
||||||
let maybe_layer = system_log().unwrap();
|
|
||||||
if let Some(layer) = maybe_layer {
|
|
||||||
let logger = layer.with_subscriber(FmtSubscriber::new());
|
|
||||||
tracing::subscriber::set_global_default(logger).context("Failed to set the global tracing subscriber").unwrap();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
burrow::ensureroot::ensure_root();
|
burrow::ensureroot::ensure_root();
|
||||||
let iface2 = retrieve();
|
|
||||||
tracing::info!("{}", iface2);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(any(target_os = "linux", target_vendor = "apple"))]
|
let iface = TunInterface::new()?;
|
||||||
async fn try_stop() -> Result<()> {
|
println!("{:?}", iface.name());
|
||||||
let mut client = DaemonClient::new().await?;
|
|
||||||
client.send_command(DaemonCommand::Stop).await?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(not(any(target_os = "linux", target_vendor = "apple")))]
|
|
||||||
async fn try_start() -> Result<()> {
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(not(any(target_os = "linux", target_vendor = "apple")))]
|
|
||||||
async fn try_retrieve() -> Result<()> {
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(not(any(target_os = "linux", target_vendor = "apple")))]
|
|
||||||
async fn try_stop() -> Result<()> {
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::main(flavor = "current_thread")]
|
#[tokio::main(flavor = "current_thread")]
|
||||||
async fn main() -> Result<()> {
|
async fn main() {
|
||||||
tracing::info!("Platform: {}", std::env::consts::OS);
|
println!("Platform: {}", std::env::consts::OS);
|
||||||
|
|
||||||
let cli = Cli::parse();
|
let cli = Cli::parse();
|
||||||
match &cli.command {
|
match &cli.command {
|
||||||
Commands::Start(..) => {
|
Commands::Start(..) => {
|
||||||
try_start().await.unwrap();
|
try_main().await.unwrap();
|
||||||
tracing::info!("FINISHED");
|
|
||||||
}
|
|
||||||
Commands::Retrieve(..) => {
|
|
||||||
try_retrieve().await.unwrap();
|
|
||||||
tracing::info!("FINISHED");
|
|
||||||
}
|
|
||||||
Commands::Stop => {
|
|
||||||
try_stop().await.unwrap();
|
|
||||||
}
|
|
||||||
Commands::Daemon(_) => daemon::daemon_main().await?,
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(target_os = "linux")]
|
|
||||||
fn system_log() -> anyhow::Result<Option<tracing_journald::Layer>> {
|
|
||||||
let maybe_journald = tracing_journald::layer();
|
|
||||||
match maybe_journald {
|
|
||||||
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
|
|
||||||
tracing::trace!("journald not found");
|
|
||||||
Ok(None)
|
|
||||||
},
|
|
||||||
_ => Ok(Some(maybe_journald?))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(target_vendor = "apple")]
|
|
||||||
fn system_log() -> anyhow::Result<Option<OsLogger>> {
|
|
||||||
Ok(Some(OsLogger::new("com.hackclub.burrow", "default")))
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,9 +0,0 @@
|
||||||
[Unit]
|
|
||||||
Description=Burrow
|
|
||||||
After=burrow.socket
|
|
||||||
|
|
||||||
[Service]
|
|
||||||
ExecStart=/usr/local/bin/burrow daemon
|
|
||||||
|
|
||||||
[Install]
|
|
||||||
WantedBy=multi-user.target
|
|
||||||
|
|
@ -1,8 +0,0 @@
|
||||||
[Unit]
|
|
||||||
Description=Burrow Socket
|
|
||||||
|
|
||||||
[Socket]
|
|
||||||
ListenStream=/run/burrow.sock
|
|
||||||
|
|
||||||
[Install]
|
|
||||||
WantedBy=sockets.target
|
|
||||||
|
|
@ -10,14 +10,11 @@ nix = { version = "0.26", features = ["ioctl"] }
|
||||||
socket2 = "0.4"
|
socket2 = "0.4"
|
||||||
tokio = { version = "1.28", features = [] }
|
tokio = { version = "1.28", features = [] }
|
||||||
byteorder = "1.4"
|
byteorder = "1.4"
|
||||||
tracing = "0.1"
|
|
||||||
log = "0.4"
|
log = "0.4"
|
||||||
serde = { version = "1", features = ["derive"], optional = true }
|
|
||||||
|
|
||||||
futures = { version = "0.3.28", optional = true }
|
futures = { version = "0.3.28", optional = true }
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
serde = ["dep:serde"]
|
|
||||||
tokio = ["tokio/net", "dep:futures"]
|
tokio = ["tokio/net", "dep:futures"]
|
||||||
|
|
||||||
[target.'cfg(feature = "tokio")'.dev-dependencies]
|
[target.'cfg(feature = "tokio")'.dev-dependencies]
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
#![deny(missing_debug_implementations)]
|
|
||||||
|
|
||||||
#[cfg(target_os = "windows")]
|
#[cfg(target_os = "windows")]
|
||||||
#[path = "windows/mod.rs"]
|
#[path = "windows/mod.rs"]
|
||||||
mod imp;
|
mod imp;
|
||||||
|
|
|
||||||
|
|
@ -3,8 +3,7 @@ use std::io::Error;
|
||||||
|
|
||||||
use super::TunInterface;
|
use super::TunInterface;
|
||||||
|
|
||||||
#[derive(Debug, Clone, Default)]
|
#[derive(Default)]
|
||||||
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
|
||||||
pub struct TunOptions {
|
pub struct TunOptions {
|
||||||
/// (Windows + Linux) Name the tun interface.
|
/// (Windows + Linux) Name the tun interface.
|
||||||
pub(crate) name: Option<String>,
|
pub(crate) name: Option<String>,
|
||||||
|
|
|
||||||
|
|
@ -1,21 +1,17 @@
|
||||||
use std::io;
|
use std::io;
|
||||||
use tokio::io::unix::AsyncFd;
|
use tokio::io::unix::AsyncFd;
|
||||||
use tracing::instrument;
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct TunInterface {
|
pub struct TunInterface {
|
||||||
inner: AsyncFd<crate::TunInterface>,
|
inner: AsyncFd<crate::TunInterface>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TunInterface {
|
impl TunInterface {
|
||||||
#[instrument]
|
|
||||||
pub fn new(tun: crate::TunInterface) -> io::Result<Self> {
|
pub fn new(tun: crate::TunInterface) -> io::Result<Self> {
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
inner: AsyncFd::new(tun)?,
|
inner: AsyncFd::new(tun)?,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
#[instrument]
|
|
||||||
pub async fn write(&self, buf: &[u8]) -> io::Result<usize> {
|
pub async fn write(&self, buf: &[u8]) -> io::Result<usize> {
|
||||||
loop {
|
loop {
|
||||||
let mut guard = self.inner.writable().await?;
|
let mut guard = self.inner.writable().await?;
|
||||||
|
|
@ -26,7 +22,6 @@ impl TunInterface {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[instrument]
|
|
||||||
pub async fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
pub async fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||||
loop {
|
loop {
|
||||||
let mut guard = self.inner.readable_mut().await?;
|
let mut guard = self.inner.readable_mut().await?;
|
||||||
|
|
@ -37,3 +32,27 @@ impl TunInterface {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use std::net::Ipv4Addr;
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_create() {
|
||||||
|
let tun = crate::TunInterface::new().unwrap();
|
||||||
|
let _async_tun = TunInterface::new(tun).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_write() {
|
||||||
|
let tun = crate::TunInterface::new().unwrap();
|
||||||
|
tun.set_ipv4_addr(Ipv4Addr::from([192, 168, 1, 10]))
|
||||||
|
.unwrap();
|
||||||
|
let async_tun = TunInterface::new(tun).unwrap();
|
||||||
|
let mut buf = [0u8; 1500];
|
||||||
|
buf[0] = 6 << 4;
|
||||||
|
let bytes_written = async_tun.write(&buf).await.unwrap();
|
||||||
|
assert!(bytes_written > 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,12 @@
|
||||||
use byteorder::{ByteOrder, NetworkEndian};
|
use byteorder::{ByteOrder, NetworkEndian};
|
||||||
use fehler::throws;
|
use fehler::throws;
|
||||||
use libc::{c_char, iovec, writev, AF_INET, AF_INET6};
|
use libc::{c_char, iovec, writev, AF_INET, AF_INET6};
|
||||||
use tracing::info;
|
use log::info;
|
||||||
use socket2::{Domain, SockAddr, Socket, Type};
|
use socket2::{Domain, SockAddr, Socket, Type};
|
||||||
use std::io::IoSlice;
|
use std::io::IoSlice;
|
||||||
use std::net::{Ipv4Addr, SocketAddrV4};
|
use std::net::{Ipv4Addr, SocketAddrV4};
|
||||||
use std::os::fd::{AsRawFd, RawFd};
|
use std::os::fd::{AsRawFd, RawFd};
|
||||||
use std::{io::Error, mem};
|
use std::{io::Error, mem};
|
||||||
use tracing::instrument;
|
|
||||||
|
|
||||||
mod kern_control;
|
mod kern_control;
|
||||||
mod sys;
|
mod sys;
|
||||||
|
|
@ -24,19 +23,16 @@ pub struct TunInterface {
|
||||||
|
|
||||||
impl TunInterface {
|
impl TunInterface {
|
||||||
#[throws]
|
#[throws]
|
||||||
#[instrument]
|
|
||||||
pub fn new() -> TunInterface {
|
pub fn new() -> TunInterface {
|
||||||
Self::new_with_options(TunOptions::new())?
|
Self::new_with_options(TunOptions::new())?
|
||||||
}
|
}
|
||||||
|
|
||||||
#[throws]
|
#[throws]
|
||||||
#[instrument]
|
|
||||||
pub fn new_with_options(_: TunOptions) -> TunInterface {
|
pub fn new_with_options(_: TunOptions) -> TunInterface {
|
||||||
TunInterface::connect(0)?
|
TunInterface::connect(0)?
|
||||||
}
|
}
|
||||||
|
|
||||||
#[throws]
|
#[throws]
|
||||||
#[instrument]
|
|
||||||
fn connect(index: u32) -> TunInterface {
|
fn connect(index: u32) -> TunInterface {
|
||||||
use socket2::{Domain, Protocol, Socket, Type};
|
use socket2::{Domain, Protocol, Socket, Type};
|
||||||
|
|
||||||
|
|
@ -52,7 +48,6 @@ impl TunInterface {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[throws]
|
#[throws]
|
||||||
#[instrument]
|
|
||||||
pub fn name(&self) -> String {
|
pub fn name(&self) -> String {
|
||||||
let mut buf = [0 as c_char; sys::IFNAMSIZ];
|
let mut buf = [0 as c_char; sys::IFNAMSIZ];
|
||||||
let mut len = buf.len() as sys::socklen_t;
|
let mut len = buf.len() as sys::socklen_t;
|
||||||
|
|
@ -67,7 +62,6 @@ impl TunInterface {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[throws]
|
#[throws]
|
||||||
#[instrument]
|
|
||||||
fn ifreq(&self) -> sys::ifreq {
|
fn ifreq(&self) -> sys::ifreq {
|
||||||
let mut iff: sys::ifreq = unsafe { mem::zeroed() };
|
let mut iff: sys::ifreq = unsafe { mem::zeroed() };
|
||||||
iff.ifr_name = string_to_ifname(&self.name()?);
|
iff.ifr_name = string_to_ifname(&self.name()?);
|
||||||
|
|
@ -75,7 +69,6 @@ impl TunInterface {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[throws]
|
#[throws]
|
||||||
#[instrument]
|
|
||||||
pub fn set_ipv4_addr(&self, addr: Ipv4Addr) {
|
pub fn set_ipv4_addr(&self, addr: Ipv4Addr) {
|
||||||
let addr = SockAddr::from(SocketAddrV4::new(addr, 0));
|
let addr = SockAddr::from(SocketAddrV4::new(addr, 0));
|
||||||
let mut iff = self.ifreq()?;
|
let mut iff = self.ifreq()?;
|
||||||
|
|
@ -85,7 +78,6 @@ impl TunInterface {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[throws]
|
#[throws]
|
||||||
#[instrument]
|
|
||||||
pub fn ipv4_addr(&self) -> Ipv4Addr {
|
pub fn ipv4_addr(&self) -> Ipv4Addr {
|
||||||
let mut iff = self.ifreq()?;
|
let mut iff = self.ifreq()?;
|
||||||
self.perform(|fd| unsafe { sys::if_get_addr(fd, &mut iff) })?;
|
self.perform(|fd| unsafe { sys::if_get_addr(fd, &mut iff) })?;
|
||||||
|
|
@ -95,15 +87,11 @@ impl TunInterface {
|
||||||
|
|
||||||
#[throws]
|
#[throws]
|
||||||
fn perform<R>(&self, perform: impl FnOnce(RawFd) -> Result<R, nix::Error>) -> R {
|
fn perform<R>(&self, perform: impl FnOnce(RawFd) -> Result<R, nix::Error>) -> R {
|
||||||
let span = tracing::info_span!("perform", fd = self.as_raw_fd());
|
|
||||||
let _enter = span.enter();
|
|
||||||
|
|
||||||
let socket = Socket::new(Domain::IPV4, Type::DGRAM, None)?;
|
let socket = Socket::new(Domain::IPV4, Type::DGRAM, None)?;
|
||||||
perform(socket.as_raw_fd())?
|
perform(socket.as_raw_fd())?
|
||||||
}
|
}
|
||||||
|
|
||||||
#[throws]
|
#[throws]
|
||||||
#[instrument]
|
|
||||||
pub fn mtu(&self) -> i32 {
|
pub fn mtu(&self) -> i32 {
|
||||||
let mut iff = self.ifreq()?;
|
let mut iff = self.ifreq()?;
|
||||||
self.perform(|fd| unsafe { sys::if_get_mtu(fd, &mut iff) })?;
|
self.perform(|fd| unsafe { sys::if_get_mtu(fd, &mut iff) })?;
|
||||||
|
|
@ -113,7 +101,6 @@ impl TunInterface {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[throws]
|
#[throws]
|
||||||
#[instrument]
|
|
||||||
pub fn set_mtu(&self, mtu: i32) {
|
pub fn set_mtu(&self, mtu: i32) {
|
||||||
let mut iff = self.ifreq()?;
|
let mut iff = self.ifreq()?;
|
||||||
iff.ifr_ifru.ifru_mtu = mtu;
|
iff.ifr_ifru.ifru_mtu = mtu;
|
||||||
|
|
@ -122,7 +109,6 @@ impl TunInterface {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[throws]
|
#[throws]
|
||||||
#[instrument]
|
|
||||||
pub fn netmask(&self) -> Ipv4Addr {
|
pub fn netmask(&self) -> Ipv4Addr {
|
||||||
let mut iff = self.ifreq()?;
|
let mut iff = self.ifreq()?;
|
||||||
self.perform(|fd| unsafe { sys::if_get_netmask(fd, &mut iff) })?;
|
self.perform(|fd| unsafe { sys::if_get_netmask(fd, &mut iff) })?;
|
||||||
|
|
@ -134,7 +120,6 @@ impl TunInterface {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[throws]
|
#[throws]
|
||||||
#[instrument]
|
|
||||||
pub fn set_netmask(&self, addr: Ipv4Addr) {
|
pub fn set_netmask(&self, addr: Ipv4Addr) {
|
||||||
let addr = SockAddr::from(SocketAddrV4::new(addr, 0));
|
let addr = SockAddr::from(SocketAddrV4::new(addr, 0));
|
||||||
let mut iff = self.ifreq()?;
|
let mut iff = self.ifreq()?;
|
||||||
|
|
@ -148,7 +133,6 @@ impl TunInterface {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[throws]
|
#[throws]
|
||||||
#[instrument]
|
|
||||||
pub fn send(&self, buf: &[u8]) -> usize {
|
pub fn send(&self, buf: &[u8]) -> usize {
|
||||||
use std::io::ErrorKind;
|
use std::io::ErrorKind;
|
||||||
let proto = match buf[0] >> 4 {
|
let proto = match buf[0] >> 4 {
|
||||||
|
|
@ -172,3 +156,32 @@ impl TunInterface {
|
||||||
.map_err(|_| Error::new(ErrorKind::Other, "Conversion error"))?
|
.map_err(|_| Error::new(ErrorKind::Other, "Conversion error"))?
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod test {
|
||||||
|
use super::*;
|
||||||
|
use std::net::Ipv4Addr;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn mtu() {
|
||||||
|
let interf = TunInterface::new().unwrap();
|
||||||
|
|
||||||
|
interf.set_mtu(500).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(interf.mtu().unwrap(), 500);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[throws]
|
||||||
|
fn netmask() {
|
||||||
|
let interf = TunInterface::new()?;
|
||||||
|
|
||||||
|
let netmask = Ipv4Addr::new(255, 0, 0, 0);
|
||||||
|
let addr = Ipv4Addr::new(192, 168, 1, 1);
|
||||||
|
|
||||||
|
interf.set_ipv4_addr(addr)?;
|
||||||
|
interf.set_netmask(netmask)?;
|
||||||
|
|
||||||
|
assert_eq!(interf.netmask()?, netmask);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4};
|
||||||
use std::os::fd::RawFd;
|
use std::os::fd::RawFd;
|
||||||
use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd};
|
use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd};
|
||||||
|
|
||||||
use tracing::{info, instrument};
|
use log::info;
|
||||||
|
|
||||||
use libc::in6_ifreq;
|
use libc::in6_ifreq;
|
||||||
|
|
||||||
|
|
@ -23,13 +23,11 @@ pub struct TunInterface {
|
||||||
|
|
||||||
impl TunInterface {
|
impl TunInterface {
|
||||||
#[throws]
|
#[throws]
|
||||||
#[instrument]
|
|
||||||
pub fn new() -> TunInterface {
|
pub fn new() -> TunInterface {
|
||||||
Self::new_with_options(TunOptions::new())?
|
Self::new_with_options(TunOptions::new())?
|
||||||
}
|
}
|
||||||
|
|
||||||
#[throws]
|
#[throws]
|
||||||
#[instrument]
|
|
||||||
pub(crate) fn new_with_options(options: TunOptions) -> TunInterface {
|
pub(crate) fn new_with_options(options: TunOptions) -> TunInterface {
|
||||||
let file = OpenOptions::new()
|
let file = OpenOptions::new()
|
||||||
.read(true)
|
.read(true)
|
||||||
|
|
@ -61,7 +59,6 @@ impl TunInterface {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[throws]
|
#[throws]
|
||||||
#[instrument]
|
|
||||||
pub fn name(&self) -> String {
|
pub fn name(&self) -> String {
|
||||||
let mut iff = unsafe { mem::zeroed() };
|
let mut iff = unsafe { mem::zeroed() };
|
||||||
unsafe { sys::tun_get_iff(self.socket.as_raw_fd(), &mut iff)? };
|
unsafe { sys::tun_get_iff(self.socket.as_raw_fd(), &mut iff)? };
|
||||||
|
|
@ -69,7 +66,6 @@ impl TunInterface {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[throws]
|
#[throws]
|
||||||
#[instrument]
|
|
||||||
fn ifreq(&self) -> sys::ifreq {
|
fn ifreq(&self) -> sys::ifreq {
|
||||||
let mut iff: sys::ifreq = unsafe { mem::zeroed() };
|
let mut iff: sys::ifreq = unsafe { mem::zeroed() };
|
||||||
iff.ifr_name = string_to_ifname(&self.name()?);
|
iff.ifr_name = string_to_ifname(&self.name()?);
|
||||||
|
|
@ -77,7 +73,6 @@ impl TunInterface {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[throws]
|
#[throws]
|
||||||
#[instrument]
|
|
||||||
fn in6_ifreq(&self) -> in6_ifreq {
|
fn in6_ifreq(&self) -> in6_ifreq {
|
||||||
let mut iff: in6_ifreq = unsafe { mem::zeroed() };
|
let mut iff: in6_ifreq = unsafe { mem::zeroed() };
|
||||||
iff.ifr6_ifindex = self.index()?;
|
iff.ifr6_ifindex = self.index()?;
|
||||||
|
|
@ -85,7 +80,6 @@ impl TunInterface {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[throws]
|
#[throws]
|
||||||
#[instrument]
|
|
||||||
pub fn index(&self) -> i32 {
|
pub fn index(&self) -> i32 {
|
||||||
let mut iff = self.ifreq()?;
|
let mut iff = self.ifreq()?;
|
||||||
self.perform(|fd| unsafe { sys::if_get_index(fd, &mut iff) })?;
|
self.perform(|fd| unsafe { sys::if_get_index(fd, &mut iff) })?;
|
||||||
|
|
@ -93,7 +87,6 @@ impl TunInterface {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[throws]
|
#[throws]
|
||||||
#[instrument]
|
|
||||||
pub fn set_ipv4_addr(&self, addr: Ipv4Addr) {
|
pub fn set_ipv4_addr(&self, addr: Ipv4Addr) {
|
||||||
let addr = SockAddr::from(SocketAddrV4::new(addr, 0));
|
let addr = SockAddr::from(SocketAddrV4::new(addr, 0));
|
||||||
let mut iff = self.ifreq()?;
|
let mut iff = self.ifreq()?;
|
||||||
|
|
@ -103,7 +96,6 @@ impl TunInterface {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[throws]
|
#[throws]
|
||||||
#[instrument]
|
|
||||||
pub fn ipv4_addr(&self) -> Ipv4Addr {
|
pub fn ipv4_addr(&self) -> Ipv4Addr {
|
||||||
let mut iff = self.ifreq()?;
|
let mut iff = self.ifreq()?;
|
||||||
self.perform(|fd| unsafe { sys::if_get_addr(fd, &mut iff) })?;
|
self.perform(|fd| unsafe { sys::if_get_addr(fd, &mut iff) })?;
|
||||||
|
|
@ -112,31 +104,6 @@ impl TunInterface {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[throws]
|
#[throws]
|
||||||
#[instrument]
|
|
||||||
pub fn set_broadcast_addr(&self, addr: Ipv4Addr) {
|
|
||||||
let addr = SockAddr::from(SocketAddrV4::new(addr, 0));
|
|
||||||
let mut iff = self.ifreq()?;
|
|
||||||
iff.ifr_ifru.ifru_broadaddr = unsafe { *addr.as_ptr() };
|
|
||||||
self.perform(|fd| unsafe { sys::if_set_brdaddr(fd, &iff) })?;
|
|
||||||
info!(
|
|
||||||
"broadcast_addr_set: {:?} (fd: {:?})",
|
|
||||||
addr,
|
|
||||||
self.as_raw_fd()
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[throws]
|
|
||||||
#[instrument]
|
|
||||||
pub fn broadcast_addr(&self) -> Ipv4Addr {
|
|
||||||
let mut iff = self.ifreq()?;
|
|
||||||
self.perform(|fd| unsafe { sys::if_get_brdaddr(fd, &mut iff) })?;
|
|
||||||
let addr =
|
|
||||||
unsafe { *(&iff.ifr_ifru.ifru_broadaddr as *const _ as *const sys::sockaddr_in) };
|
|
||||||
Ipv4Addr::from(u32::from_be(addr.sin_addr.s_addr))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[throws]
|
|
||||||
#[instrument]
|
|
||||||
pub fn set_ipv6_addr(&self, addr: Ipv6Addr) {
|
pub fn set_ipv6_addr(&self, addr: Ipv6Addr) {
|
||||||
let mut iff = self.in6_ifreq()?;
|
let mut iff = self.in6_ifreq()?;
|
||||||
iff.ifr6_addr.s6_addr = addr.octets();
|
iff.ifr6_addr.s6_addr = addr.octets();
|
||||||
|
|
@ -145,7 +112,6 @@ impl TunInterface {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[throws]
|
#[throws]
|
||||||
#[instrument]
|
|
||||||
pub fn set_mtu(&self, mtu: i32) {
|
pub fn set_mtu(&self, mtu: i32) {
|
||||||
let mut iff = self.ifreq()?;
|
let mut iff = self.ifreq()?;
|
||||||
iff.ifr_ifru.ifru_mtu = mtu;
|
iff.ifr_ifru.ifru_mtu = mtu;
|
||||||
|
|
@ -154,7 +120,6 @@ impl TunInterface {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[throws]
|
#[throws]
|
||||||
#[instrument]
|
|
||||||
pub fn mtu(&self) -> i32 {
|
pub fn mtu(&self) -> i32 {
|
||||||
let mut iff = self.ifreq()?;
|
let mut iff = self.ifreq()?;
|
||||||
self.perform(|fd| unsafe { sys::if_get_mtu(fd, &mut iff) })?;
|
self.perform(|fd| unsafe { sys::if_get_mtu(fd, &mut iff) })?;
|
||||||
|
|
@ -164,7 +129,6 @@ impl TunInterface {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[throws]
|
#[throws]
|
||||||
#[instrument]
|
|
||||||
pub fn set_netmask(&self, addr: Ipv4Addr) {
|
pub fn set_netmask(&self, addr: Ipv4Addr) {
|
||||||
let addr = SockAddr::from(SocketAddrV4::new(addr, 0));
|
let addr = SockAddr::from(SocketAddrV4::new(addr, 0));
|
||||||
|
|
||||||
|
|
@ -181,7 +145,6 @@ impl TunInterface {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[throws]
|
#[throws]
|
||||||
#[instrument]
|
|
||||||
pub fn netmask(&self) -> Ipv4Addr {
|
pub fn netmask(&self) -> Ipv4Addr {
|
||||||
let mut iff = self.ifreq()?;
|
let mut iff = self.ifreq()?;
|
||||||
self.perform(|fd| unsafe { sys::if_get_netmask(fd, &mut iff) })?;
|
self.perform(|fd| unsafe { sys::if_get_netmask(fd, &mut iff) })?;
|
||||||
|
|
@ -194,25 +157,47 @@ impl TunInterface {
|
||||||
|
|
||||||
#[throws]
|
#[throws]
|
||||||
fn perform<R>(&self, perform: impl FnOnce(RawFd) -> Result<R, nix::Error>) -> R {
|
fn perform<R>(&self, perform: impl FnOnce(RawFd) -> Result<R, nix::Error>) -> R {
|
||||||
let span = tracing::info_span!("perform");
|
|
||||||
let _enter = span.enter();
|
|
||||||
|
|
||||||
let socket = Socket::new(Domain::IPV4, Type::DGRAM, None)?;
|
let socket = Socket::new(Domain::IPV4, Type::DGRAM, None)?;
|
||||||
perform(socket.as_raw_fd())?
|
perform(socket.as_raw_fd())?
|
||||||
}
|
}
|
||||||
|
|
||||||
#[throws]
|
#[throws]
|
||||||
fn perform6<R>(&self, perform: impl FnOnce(RawFd) -> Result<R, nix::Error>) -> R {
|
fn perform6<R>(&self, perform: impl FnOnce(RawFd) -> Result<R, nix::Error>) -> R {
|
||||||
let span = tracing::info_span!("perform");
|
|
||||||
let _enter = span.enter();
|
|
||||||
|
|
||||||
let socket = Socket::new(Domain::IPV6, Type::DGRAM, None)?;
|
let socket = Socket::new(Domain::IPV6, Type::DGRAM, None)?;
|
||||||
perform(socket.as_raw_fd())?
|
perform(socket.as_raw_fd())?
|
||||||
}
|
}
|
||||||
|
|
||||||
#[throws]
|
#[throws]
|
||||||
#[instrument]
|
|
||||||
pub fn send(&self, buf: &[u8]) -> usize {
|
pub fn send(&self, buf: &[u8]) -> usize {
|
||||||
self.socket.send(buf)?
|
self.socket.send(buf)?
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod test {
|
||||||
|
use super::TunInterface;
|
||||||
|
use std::net::Ipv4Addr;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn mtu() {
|
||||||
|
let interf = TunInterface::new().unwrap();
|
||||||
|
|
||||||
|
interf.set_mtu(500).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(interf.mtu().unwrap(), 500);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[throws]
|
||||||
|
fn netmask() {
|
||||||
|
let interf = TunInterface::new()?;
|
||||||
|
|
||||||
|
let netmask = Ipv4Addr::new(255, 0, 0, 0);
|
||||||
|
let addr = Ipv4Addr::new(192, 168, 1, 1);
|
||||||
|
|
||||||
|
interf.set_ipv4_addr(addr)?;
|
||||||
|
interf.set_netmask(netmask)?;
|
||||||
|
|
||||||
|
assert_eq!(interf.netmask()?, netmask);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -18,12 +18,10 @@ ioctl_read_bad!(
|
||||||
);
|
);
|
||||||
ioctl_read_bad!(if_get_index, libc::SIOCGIFINDEX, libc::ifreq);
|
ioctl_read_bad!(if_get_index, libc::SIOCGIFINDEX, libc::ifreq);
|
||||||
ioctl_read_bad!(if_get_addr, libc::SIOCGIFADDR, libc::ifreq);
|
ioctl_read_bad!(if_get_addr, libc::SIOCGIFADDR, libc::ifreq);
|
||||||
ioctl_read_bad!(if_get_brdaddr, libc::SIOCGIFBRDADDR, libc::ifreq);
|
|
||||||
ioctl_read_bad!(if_get_mtu, libc::SIOCGIFMTU, libc::ifreq);
|
ioctl_read_bad!(if_get_mtu, libc::SIOCGIFMTU, libc::ifreq);
|
||||||
ioctl_read_bad!(if_get_netmask, libc::SIOCGIFNETMASK, libc::ifreq);
|
ioctl_read_bad!(if_get_netmask, libc::SIOCGIFNETMASK, libc::ifreq);
|
||||||
|
|
||||||
ioctl_write_ptr_bad!(if_set_addr, libc::SIOCSIFADDR, libc::ifreq);
|
ioctl_write_ptr_bad!(if_set_addr, libc::SIOCSIFADDR, libc::ifreq);
|
||||||
ioctl_write_ptr_bad!(if_set_addr6, libc::SIOCSIFADDR, libc::in6_ifreq);
|
ioctl_write_ptr_bad!(if_set_addr6, libc::SIOCSIFADDR, libc::in6_ifreq);
|
||||||
ioctl_write_ptr_bad!(if_set_brdaddr, libc::SIOCSIFBRDADDR, libc::ifreq);
|
|
||||||
ioctl_write_ptr_bad!(if_set_mtu, libc::SIOCSIFMTU, libc::ifreq);
|
ioctl_write_ptr_bad!(if_set_mtu, libc::SIOCSIFMTU, libc::ifreq);
|
||||||
ioctl_write_ptr_bad!(if_set_netmask, libc::SIOCSIFNETMASK, libc::ifreq);
|
ioctl_write_ptr_bad!(if_set_netmask, libc::SIOCSIFNETMASK, libc::ifreq);
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,6 @@ use std::{
|
||||||
io::{Error, Read},
|
io::{Error, Read},
|
||||||
os::fd::{AsRawFd, FromRawFd, IntoRawFd, RawFd},
|
os::fd::{AsRawFd, FromRawFd, IntoRawFd, RawFd},
|
||||||
};
|
};
|
||||||
use tracing::instrument;
|
|
||||||
|
|
||||||
use super::TunOptions;
|
use super::TunOptions;
|
||||||
|
|
||||||
|
|
@ -42,13 +41,11 @@ impl IntoRawFd for TunInterface {
|
||||||
|
|
||||||
impl TunInterface {
|
impl TunInterface {
|
||||||
#[throws]
|
#[throws]
|
||||||
#[instrument]
|
|
||||||
pub fn recv(&mut self, buf: &mut [u8]) -> usize {
|
pub fn recv(&mut self, buf: &mut [u8]) -> usize {
|
||||||
self.socket.read(buf)?
|
self.socket.read(buf)?
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[instrument]
|
|
||||||
pub fn ifname_to_string(buf: [libc::c_char; libc::IFNAMSIZ]) -> String {
|
pub fn ifname_to_string(buf: [libc::c_char; libc::IFNAMSIZ]) -> String {
|
||||||
// TODO: Switch to `CStr::from_bytes_until_nul` when stabilized
|
// TODO: Switch to `CStr::from_bytes_until_nul` when stabilized
|
||||||
unsafe {
|
unsafe {
|
||||||
|
|
@ -59,10 +56,44 @@ pub fn ifname_to_string(buf: [libc::c_char; libc::IFNAMSIZ]) -> String {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[instrument]
|
|
||||||
pub fn string_to_ifname(name: &str) -> [libc::c_char; libc::IFNAMSIZ] {
|
pub fn string_to_ifname(name: &str) -> [libc::c_char; libc::IFNAMSIZ] {
|
||||||
let mut buf = [0 as libc::c_char; libc::IFNAMSIZ];
|
let mut buf = [0 as libc::c_char; libc::IFNAMSIZ];
|
||||||
let len = name.len().min(buf.len());
|
let len = name.len().min(buf.len());
|
||||||
buf[..len].copy_from_slice(unsafe { &*(name.as_bytes() as *const _ as *const [libc::c_char]) });
|
buf[..len].copy_from_slice(unsafe { &*(name.as_bytes() as *const _ as *const [libc::c_char]) });
|
||||||
buf
|
buf
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod test {
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
use std::net::Ipv4Addr;
|
||||||
|
|
||||||
|
#[throws]
|
||||||
|
#[test]
|
||||||
|
fn tst_read() {
|
||||||
|
// This test is interactive, you need to send a packet to any server through 192.168.1.10
|
||||||
|
// EG. `sudo route add 8.8.8.8 192.168.1.10`,
|
||||||
|
//`dig @8.8.8.8 hackclub.com`
|
||||||
|
let mut tun = TunInterface::new()?;
|
||||||
|
println!("tun name: {:?}", tun.name()?);
|
||||||
|
tun.set_ipv4_addr(Ipv4Addr::from([192, 168, 1, 10]))?;
|
||||||
|
println!("tun ip: {:?}", tun.ipv4_addr()?);
|
||||||
|
println!("Waiting for a packet...");
|
||||||
|
let buf = &mut [0u8; 1500];
|
||||||
|
let res = tun.recv(buf);
|
||||||
|
println!("Received!");
|
||||||
|
assert!(res.is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[throws]
|
||||||
|
fn write_packets() {
|
||||||
|
let tun = TunInterface::new()?;
|
||||||
|
let mut buf = [0u8; 1500];
|
||||||
|
buf[0] = 6 << 4;
|
||||||
|
let bytes_written = tun.send(&buf)?;
|
||||||
|
assert_eq!(bytes_written, 1504);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -5,18 +5,15 @@ use std::{
|
||||||
mem::MaybeUninit,
|
mem::MaybeUninit,
|
||||||
os::unix::io::{AsRawFd, IntoRawFd, RawFd},
|
os::unix::io::{AsRawFd, IntoRawFd, RawFd},
|
||||||
};
|
};
|
||||||
use tracing::instrument;
|
|
||||||
|
|
||||||
use crate::TunInterface;
|
use crate::TunInterface;
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct TunQueue {
|
pub struct TunQueue {
|
||||||
socket: socket2::Socket,
|
socket: socket2::Socket,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TunQueue {
|
impl TunQueue {
|
||||||
#[throws]
|
#[throws]
|
||||||
#[instrument]
|
|
||||||
pub fn recv(&self, buf: &mut [MaybeUninit<u8>]) -> usize {
|
pub fn recv(&self, buf: &mut [MaybeUninit<u8>]) -> usize {
|
||||||
self.socket.recv(buf)?
|
self.socket.recv(buf)?
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,3 @@
|
||||||
use std::fmt::Debug;
|
|
||||||
use fehler::throws;
|
use fehler::throws;
|
||||||
use std::io::Error;
|
use std::io::Error;
|
||||||
use std::ptr;
|
use std::ptr;
|
||||||
|
|
@ -15,15 +14,6 @@ pub struct TunInterface {
|
||||||
name: String,
|
name: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Debug for TunInterface {
|
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
||||||
f.debug_struct("TunInterface")
|
|
||||||
.field("handle", &"SYS_WINTUN_ADAPTER_HANDLE".to_string())
|
|
||||||
.field("name", &self.name)
|
|
||||||
.finish()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TunInterface {
|
impl TunInterface {
|
||||||
#[throws]
|
#[throws]
|
||||||
pub fn new() -> TunInterface {
|
pub fn new() -> TunInterface {
|
||||||
|
|
|
||||||
|
|
@ -1,2 +1 @@
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct TunQueue;
|
pub struct TunQueue;
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
use fehler::throws;
|
use fehler::throws;
|
||||||
use std::io::Error;
|
use std::io::Error;
|
||||||
use std::net::Ipv4Addr;
|
use std::net::{Ipv4Addr};
|
||||||
use tun::TunInterface;
|
use tun::TunInterface;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -11,22 +11,6 @@ fn test_create() {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
#[throws]
|
#[throws]
|
||||||
#[cfg(not(any(target_os = "windows", target_vendor = "apple")))]
|
|
||||||
fn test_set_get_broadcast_addr() {
|
|
||||||
let tun = TunInterface::new()?;
|
|
||||||
let addr = Ipv4Addr::new(10, 0, 0, 1);
|
|
||||||
tun.set_ipv4_addr(addr)?;
|
|
||||||
|
|
||||||
let broadcast_addr = Ipv4Addr::new(255, 255, 255, 0);
|
|
||||||
tun.set_broadcast_addr(broadcast_addr)?;
|
|
||||||
let result = tun.broadcast_addr()?;
|
|
||||||
|
|
||||||
assert_eq!(broadcast_addr, result);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
#[throws]
|
|
||||||
#[cfg(not(target_os = "windows"))]
|
|
||||||
fn test_set_get_ipv4() {
|
fn test_set_get_ipv4() {
|
||||||
let tun = TunInterface::new()?;
|
let tun = TunInterface::new()?;
|
||||||
|
|
||||||
|
|
@ -39,10 +23,8 @@ fn test_set_get_ipv4() {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
#[throws]
|
#[throws]
|
||||||
#[cfg(not(any(target_os = "windows", target_vendor = "apple")))]
|
#[cfg(target_os = "linux")]
|
||||||
fn test_set_get_ipv6() {
|
fn test_set_get_ipv6() {
|
||||||
use std::net::Ipv6Addr;
|
|
||||||
|
|
||||||
let tun = TunInterface::new()?;
|
let tun = TunInterface::new()?;
|
||||||
|
|
||||||
let addr = Ipv6Addr::new(1, 1, 1, 1, 1, 1, 1, 1);
|
let addr = Ipv6Addr::new(1, 1, 1, 1, 1, 1, 1, 1);
|
||||||
|
|
@ -51,29 +33,3 @@ fn test_set_get_ipv6() {
|
||||||
// let result = tun.ipv6_addr()?;
|
// let result = tun.ipv6_addr()?;
|
||||||
// assert_eq!(addr, result);
|
// assert_eq!(addr, result);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
|
||||||
#[throws]
|
|
||||||
#[cfg(not(target_os = "windows"))]
|
|
||||||
fn test_set_get_mtu() {
|
|
||||||
let interf = TunInterface::new()?;
|
|
||||||
|
|
||||||
interf.set_mtu(500)?;
|
|
||||||
|
|
||||||
assert_eq!(interf.mtu().unwrap(), 500);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
#[throws]
|
|
||||||
#[cfg(not(target_os = "windows"))]
|
|
||||||
fn test_set_get_netmask() {
|
|
||||||
let interf = TunInterface::new()?;
|
|
||||||
|
|
||||||
let netmask = Ipv4Addr::new(255, 0, 0, 0);
|
|
||||||
let addr = Ipv4Addr::new(192, 168, 1, 1);
|
|
||||||
|
|
||||||
interf.set_ipv4_addr(addr)?;
|
|
||||||
interf.set_netmask(netmask)?;
|
|
||||||
|
|
||||||
assert_eq!(interf.netmask()?, netmask);
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -1,36 +0,0 @@
|
||||||
use fehler::throws;
|
|
||||||
use std::io::Error;
|
|
||||||
|
|
||||||
use std::net::Ipv4Addr;
|
|
||||||
use tun::TunInterface;
|
|
||||||
|
|
||||||
#[throws]
|
|
||||||
#[test]
|
|
||||||
#[ignore = "requires interactivity"]
|
|
||||||
#[cfg(not(target_os = "windows"))]
|
|
||||||
fn tst_read() {
|
|
||||||
// This test is interactive, you need to send a packet to any server through 192.168.1.10
|
|
||||||
// EG. `sudo route add 8.8.8.8 192.168.1.10`,
|
|
||||||
//`dig @8.8.8.8 hackclub.com`
|
|
||||||
let mut tun = TunInterface::new()?;
|
|
||||||
println!("tun name: {:?}", tun.name()?);
|
|
||||||
tun.set_ipv4_addr(Ipv4Addr::from([192, 168, 1, 10]))?;
|
|
||||||
println!("tun ip: {:?}", tun.ipv4_addr()?);
|
|
||||||
println!("Waiting for a packet...");
|
|
||||||
let buf = &mut [0u8; 1500];
|
|
||||||
let res = tun.recv(buf);
|
|
||||||
println!("Received!");
|
|
||||||
assert!(res.is_ok());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
#[throws]
|
|
||||||
#[ignore = "requires interactivity"]
|
|
||||||
#[cfg(not(target_os = "windows"))]
|
|
||||||
fn write_packets() {
|
|
||||||
let tun = TunInterface::new()?;
|
|
||||||
let mut buf = [0u8; 1500];
|
|
||||||
buf[0] = 6 << 4;
|
|
||||||
let bytes_written = tun.send(&buf)?;
|
|
||||||
assert_eq!(bytes_written, 1504);
|
|
||||||
}
|
|
||||||
|
|
@ -1,22 +0,0 @@
|
||||||
use std::net::Ipv4Addr;
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
#[cfg(all(feature = "tokio", not(target_os = "windows")))]
|
|
||||||
async fn test_create() {
|
|
||||||
let tun = tun::TunInterface::new().unwrap();
|
|
||||||
let async_tun = tun::tokio::TunInterface::new(tun).unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
#[ignore = "requires interactivity"]
|
|
||||||
#[cfg(all(feature = "tokio", not(target_os = "windows")))]
|
|
||||||
async fn test_write() {
|
|
||||||
let tun = tun::TunInterface::new().unwrap();
|
|
||||||
tun.set_ipv4_addr(Ipv4Addr::from([192, 168, 1, 10]))
|
|
||||||
.unwrap();
|
|
||||||
let async_tun = tun::tokio::TunInterface::new(tun).unwrap();
|
|
||||||
let mut buf = [0u8; 1500];
|
|
||||||
buf[0] = 6 << 4;
|
|
||||||
let bytes_written = async_tun.write(&buf).await.unwrap();
|
|
||||||
assert!(bytes_written > 0);
|
|
||||||
}
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue