From 1b39eca069edb3d48b27b19971e6627b816e6697 Mon Sep 17 00:00:00 2001 From: Conrad Kramer Date: Sat, 9 Sep 2023 11:16:19 -0700 Subject: [PATCH] boringtun wip --- burrow/src/boringtun/Cargo.toml | 66 ++ burrow/src/boringtun/src/device/mod.rs | 902 +++++++++++++++++ burrow/src/boringtun/src/device/peer.rs | 170 ++++ burrow/src/boringtun/src/lib.rs | 27 + burrow/src/boringtun/src/noise/errors.rs | 23 + burrow/src/boringtun/src/noise/handshake.rs | 941 ++++++++++++++++++ burrow/src/boringtun/src/noise/mod.rs | 799 +++++++++++++++ .../src/boringtun/src/noise/rate_limiter.rs | 193 ++++ burrow/src/boringtun/src/noise/session.rs | 329 ++++++ burrow/src/boringtun/src/noise/timers.rs | 335 +++++++ burrow/src/boringtun/src/serialization.rs | 33 + 11 files changed, 3818 insertions(+) create mode 100755 burrow/src/boringtun/Cargo.toml create mode 100755 burrow/src/boringtun/src/device/mod.rs create mode 100755 burrow/src/boringtun/src/device/peer.rs create mode 100755 burrow/src/boringtun/src/lib.rs create mode 100755 burrow/src/boringtun/src/noise/errors.rs create mode 100755 burrow/src/boringtun/src/noise/handshake.rs create mode 100755 burrow/src/boringtun/src/noise/mod.rs create mode 100755 burrow/src/boringtun/src/noise/rate_limiter.rs create mode 100755 burrow/src/boringtun/src/noise/session.rs create mode 100755 burrow/src/boringtun/src/noise/timers.rs create mode 100755 burrow/src/boringtun/src/serialization.rs diff --git a/burrow/src/boringtun/Cargo.toml b/burrow/src/boringtun/Cargo.toml new file mode 100755 index 0000000..454b8c8 --- /dev/null +++ b/burrow/src/boringtun/Cargo.toml @@ -0,0 +1,66 @@ +[package] +name = "boringtun" +description = "an implementation of the WireGuard® protocol designed for portability and speed" +version = "0.6.0" +authors = [ + "Noah Kennedy ", + "Andy Grover ", + "Jeff Hiner ", +] +license = "BSD-3-Clause" +repository = "https://github.com/cloudflare/boringtun" +documentation = "https://docs.rs/boringtun/0.5.2/boringtun/" +edition = "2018" + +[features] +default = [] +device = ["socket2", "thiserror"] +jni-bindings = ["ffi-bindings", "jni"] +ffi-bindings = ["tracing-subscriber"] +# mocks std::time::Instant with mock_instant +mock-instant = ["mock_instant"] + +[workspace] + +[dependencies] +base64 = "0.13" +hex = "0.4" +untrusted = "0.9.0" +libc = "0.2" +parking_lot = "0.12" +tracing = "0.1.29" +tracing-subscriber = { version = "0.3", features = ["fmt"], optional = true } +ip_network = "0.4.1" +ip_network_table = "0.2.0" +ring = "0.16" +x25519-dalek = { version = "2.0.0", features = [ + "reusable_secrets", + "static_secrets", +] } +rand_core = { version = "0.6.3", features = ["getrandom"] } +chacha20poly1305 = "0.10.0-pre.1" +aead = "0.5.0-pre.2" +blake2 = "0.10" +hmac = "0.12" +jni = { version = "0.19.0", optional = true } +mock_instant = { version = "0.2", optional = true } +socket2 = { version = "0.4.7", features = ["all"], optional = true } +thiserror = { version = "1", optional = true } + +[target.'cfg(unix)'.dependencies] +nix = { version = "0.25", default-features = false, features = [ + "time", + "user", +] } + +[dev-dependencies] +etherparse = "0.12" +tracing-subscriber = "0.3" +criterion = { version = "0.3.5", features = ["html_reports"] } + +[lib] +crate-type = ["staticlib", "cdylib", "rlib"] + +[[bench]] +name = "crypto_benches" +harness = false diff --git a/burrow/src/boringtun/src/device/mod.rs b/burrow/src/boringtun/src/device/mod.rs new file mode 100755 index 0000000..bc14c89 --- /dev/null +++ b/burrow/src/boringtun/src/device/mod.rs @@ -0,0 +1,902 @@ +// Copyright (c) 2019 Cloudflare, Inc. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +pub mod allowed_ips; +pub mod api; +mod dev_lock; +pub mod drop_privileges; +#[cfg(test)] +mod integration_tests; +pub mod peer; + +#[cfg(any(target_os = "macos", target_os = "ios"))] +#[path = "kqueue.rs"] +pub mod poll; + +#[cfg(target_os = "linux")] +#[path = "epoll.rs"] +pub mod poll; + +#[cfg(any(target_os = "macos", target_os = "ios"))] +#[path = "tun_darwin.rs"] +pub mod tun; + +#[cfg(target_os = "linux")] +#[path = "tun_linux.rs"] +pub mod tun; + +use std::collections::HashMap; +use std::io::{self, Write as _}; +use std::mem::MaybeUninit; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; +use std::os::unix::io::AsRawFd; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use std::thread; +use std::thread::JoinHandle; + +use crate::noise::errors::WireGuardError; +use crate::noise::handshake::parse_handshake_anon; +use crate::noise::rate_limiter::RateLimiter; +use crate::noise::{Packet, Tunn, TunnResult}; +use crate::x25519; +use allowed_ips::AllowedIps; +use parking_lot::Mutex; +use peer::{AllowedIP, Peer}; +use poll::{EventPoll, EventRef, WaitResult}; +use rand_core::{OsRng, RngCore}; +use socket2::{Domain, Protocol, Type}; +use tun::TunSocket; + +use dev_lock::{Lock, LockReadGuard}; + +const HANDSHAKE_RATE_LIMIT: u64 = 100; // The number of handshakes per second we can tolerate before using cookies + +const MAX_UDP_SIZE: usize = (1 << 16) - 1; +const MAX_ITR: usize = 100; // Number of packets to handle per handler call + +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("i/o error: {0}")] + IoError(#[from] io::Error), + #[error("{0}")] + Socket(io::Error), + #[error("{0}")] + Bind(String), + #[error("{0}")] + FCntl(io::Error), + #[error("{0}")] + EventQueue(io::Error), + #[error("{0}")] + IOCtl(io::Error), + #[error("{0}")] + Connect(String), + #[error("{0}")] + SetSockOpt(String), + #[error("Invalid tunnel name")] + InvalidTunnelName, + #[cfg(any(target_os = "macos", target_os = "ios"))] + #[error("{0}")] + GetSockOpt(io::Error), + #[error("{0}")] + GetSockName(String), + #[cfg(target_os = "linux")] + #[error("{0}")] + Timer(io::Error), + #[error("iface read: {0}")] + IfaceRead(io::Error), + #[error("{0}")] + DropPrivileges(String), + #[error("API socket error: {0}")] + ApiSocket(io::Error), +} + +// What the event loop should do after a handler returns +enum Action { + Continue, // Continue the loop + Yield, // Yield the read lock and acquire it again + Exit, // Stop the loop +} + +// Event handler function +type Handler = Box, &mut ThreadData) -> Action + Send + Sync>; + +pub struct DeviceHandle { + device: Arc>, // The interface this handle owns + threads: Vec>, +} + +#[derive(Debug, Clone, Copy)] +pub struct DeviceConfig { + pub n_threads: usize, + pub use_connected_socket: bool, + #[cfg(target_os = "linux")] + pub use_multi_queue: bool, + #[cfg(target_os = "linux")] + pub uapi_fd: i32, +} + +impl Default for DeviceConfig { + fn default() -> Self { + DeviceConfig { + n_threads: 4, + use_connected_socket: true, + #[cfg(target_os = "linux")] + use_multi_queue: true, + #[cfg(target_os = "linux")] + uapi_fd: -1, + } + } +} + +pub struct Device { + key_pair: Option<(x25519::StaticSecret, x25519::PublicKey)>, + queue: Arc>, + + listen_port: u16, + fwmark: Option, + + iface: Arc, + udp4: Option, + udp6: Option, + + yield_notice: Option, + exit_notice: Option, + + peers: HashMap>>, + peers_by_ip: AllowedIps>>, + peers_by_idx: HashMap>>, + next_index: IndexLfsr, + + config: DeviceConfig, + + cleanup_paths: Vec, + + mtu: AtomicUsize, + + rate_limiter: Option>, + + #[cfg(target_os = "linux")] + uapi_fd: i32, +} + +struct ThreadData { + iface: Arc, + src_buf: [u8; MAX_UDP_SIZE], + dst_buf: [u8; MAX_UDP_SIZE], +} + +impl DeviceHandle { + pub fn new(name: &str, config: DeviceConfig) -> Result { + let n_threads = config.n_threads; + let mut wg_interface = Device::new(name, config)?; + wg_interface.open_listen_socket(0)?; // Start listening on a random port + + let interface_lock = Arc::new(Lock::new(wg_interface)); + + let mut threads = vec![]; + + for i in 0..n_threads { + threads.push({ + let dev = Arc::clone(&interface_lock); + thread::spawn(move || DeviceHandle::event_loop(i, &dev)) + }); + } + + Ok(DeviceHandle { + device: interface_lock, + threads, + }) + } + + pub fn wait(&mut self) { + while let Some(thread) = self.threads.pop() { + thread.join().unwrap(); + } + } + + pub fn clean(&mut self) { + for path in &self.device.read().cleanup_paths { + // attempt to remove any file we created in the work dir + let _ = std::fs::remove_file(path); + } + } + + fn event_loop(_i: usize, device: &Lock) { + #[cfg(target_os = "linux")] + let mut thread_local = ThreadData { + src_buf: [0u8; MAX_UDP_SIZE], + dst_buf: [0u8; MAX_UDP_SIZE], + iface: if _i == 0 || !device.read().config.use_multi_queue { + // For the first thread use the original iface + Arc::clone(&device.read().iface) + } else { + // For for the rest create a new iface queue + let iface_local = Arc::new( + TunSocket::new(&device.read().iface.name().unwrap()) + .unwrap() + .set_non_blocking() + .unwrap(), + ); + + device + .read() + .register_iface_handler(Arc::clone(&iface_local)) + .ok(); + + iface_local + }, + }; + + #[cfg(not(target_os = "linux"))] + let mut thread_local = ThreadData { + src_buf: [0u8; MAX_UDP_SIZE], + dst_buf: [0u8; MAX_UDP_SIZE], + iface: Arc::clone(&device.read().iface), + }; + + #[cfg(not(target_os = "linux"))] + let uapi_fd = -1; + #[cfg(target_os = "linux")] + let uapi_fd = device.read().uapi_fd; + + loop { + // The event loop keeps a read lock on the device, because we assume write access is rarely needed + let mut device_lock = device.read(); + let queue = Arc::clone(&device_lock.queue); + + loop { + match queue.wait() { + WaitResult::Ok(handler) => { + let action = (*handler)(&mut device_lock, &mut thread_local); + match action { + Action::Continue => {} + Action::Yield => break, + Action::Exit => { + device_lock.trigger_exit(); + return; + } + } + } + WaitResult::EoF(handler) => { + if uapi_fd >= 0 && uapi_fd == handler.fd() { + device_lock.trigger_exit(); + return; + } + handler.cancel(); + } + WaitResult::Error(e) => tracing::error!(message = "Poll error", error = ?e), + } + } + } + } +} + +impl Drop for DeviceHandle { + fn drop(&mut self) { + self.device.read().trigger_exit(); + self.clean(); + } +} + +impl Device { + fn next_index(&mut self) -> u32 { + self.next_index.next() + } + + fn remove_peer(&mut self, pub_key: &x25519::PublicKey) { + if let Some(peer) = self.peers.remove(pub_key) { + // Found a peer to remove, now purge all references to it: + { + let p = peer.lock(); + p.shutdown_endpoint(); // close open udp socket and free the closure + self.peers_by_idx.remove(&p.index()); + } + self.peers_by_ip + .remove(&|p: &Arc>| Arc::ptr_eq(&peer, p)); + + tracing::info!("Peer removed"); + } + } + + #[allow(clippy::too_many_arguments)] + fn update_peer( + &mut self, + pub_key: x25519::PublicKey, + remove: bool, + _replace_ips: bool, + endpoint: Option, + allowed_ips: &[AllowedIP], + keepalive: Option, + preshared_key: Option<[u8; 32]>, + ) { + if remove { + // Completely remove a peer + return self.remove_peer(&pub_key); + } + + // Update an existing peer + if self.peers.get(&pub_key).is_some() { + // We already have a peer, we need to merge the existing config into the newly created one + panic!("Modifying existing peers is not yet supported. Remove and add again instead."); + } + + let next_index = self.next_index(); + let device_key_pair = self + .key_pair + .as_ref() + .expect("Private key must be set first"); + + let tunn = Tunn::new( + device_key_pair.0.clone(), + pub_key, + preshared_key, + keepalive, + next_index, + None, + ) + .unwrap(); + + let peer = Peer::new(tunn, next_index, endpoint, allowed_ips, preshared_key); + + let peer = Arc::new(Mutex::new(peer)); + self.peers.insert(pub_key, Arc::clone(&peer)); + self.peers_by_idx.insert(next_index, Arc::clone(&peer)); + + for AllowedIP { addr, cidr } in allowed_ips { + self.peers_by_ip + .insert(*addr, *cidr as _, Arc::clone(&peer)); + } + + tracing::info!("Peer added"); + } + + pub fn new(name: &str, config: DeviceConfig) -> Result { + let poll = EventPoll::::new()?; + + // Create a tunnel device + let iface = Arc::new(TunSocket::new(name)?.set_non_blocking()?); + let mtu = iface.mtu()?; + + #[cfg(not(target_os = "linux"))] + let uapi_fd = -1; + #[cfg(target_os = "linux")] + let uapi_fd = config.uapi_fd; + + let mut device = Device { + queue: Arc::new(poll), + iface, + config, + exit_notice: Default::default(), + yield_notice: Default::default(), + fwmark: Default::default(), + key_pair: Default::default(), + listen_port: Default::default(), + next_index: Default::default(), + peers: Default::default(), + peers_by_idx: Default::default(), + peers_by_ip: AllowedIps::new(), + udp4: Default::default(), + udp6: Default::default(), + cleanup_paths: Default::default(), + mtu: AtomicUsize::new(mtu), + rate_limiter: None, + #[cfg(target_os = "linux")] + uapi_fd, + }; + + if uapi_fd >= 0 { + device.register_api_fd(uapi_fd)?; + } else { + device.register_api_handler()?; + } + device.register_iface_handler(Arc::clone(&device.iface))?; + device.register_notifiers()?; + device.register_timers()?; + + #[cfg(target_os = "macos")] + { + // Only for macOS write the actual socket name into WG_TUN_NAME_FILE + if let Ok(name_file) = std::env::var("WG_TUN_NAME_FILE") { + if name == "utun" { + std::fs::write(&name_file, device.iface.name().unwrap().as_bytes()).unwrap(); + device.cleanup_paths.push(name_file); + } + } + } + + Ok(device) + } + + fn open_listen_socket(&mut self, mut port: u16) -> Result<(), Error> { + // Binds the network facing interfaces + // First close any existing open socket, and remove them from the event loop + if let Some(s) = self.udp4.take() { + unsafe { + // This is safe because the event loop is not running yet + self.queue.clear_event_by_fd(s.as_raw_fd()) + } + }; + + if let Some(s) = self.udp6.take() { + unsafe { self.queue.clear_event_by_fd(s.as_raw_fd()) }; + } + + for peer in self.peers.values() { + peer.lock().shutdown_endpoint(); + } + + // Then open new sockets and bind to the port + let udp_sock4 = socket2::Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))?; + udp_sock4.set_reuse_address(true)?; + udp_sock4.bind(&SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, port).into())?; + udp_sock4.set_nonblocking(true)?; + + if port == 0 { + // Random port was assigned + port = udp_sock4.local_addr()?.as_socket().unwrap().port(); + } + + let udp_sock6 = socket2::Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP))?; + udp_sock6.set_reuse_address(true)?; + udp_sock6.bind(&SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, port, 0, 0).into())?; + udp_sock6.set_nonblocking(true)?; + + self.register_udp_handler(udp_sock4.try_clone().unwrap())?; + self.register_udp_handler(udp_sock6.try_clone().unwrap())?; + self.udp4 = Some(udp_sock4); + self.udp6 = Some(udp_sock6); + + self.listen_port = port; + + Ok(()) + } + + fn set_key(&mut self, private_key: x25519::StaticSecret) { + let mut bad_peers = vec![]; + + let public_key = x25519::PublicKey::from(&private_key); + let key_pair = Some((private_key.clone(), public_key)); + + // x25519 (rightly) doesn't let us expose secret keys for comparison. + // If the public keys are the same, then the private keys are the same. + if Some(&public_key) == self.key_pair.as_ref().map(|p| &p.1) { + return; + } + + let rate_limiter = Arc::new(RateLimiter::new(&public_key, HANDSHAKE_RATE_LIMIT)); + + for peer in self.peers.values_mut() { + let mut peer_mut = peer.lock(); + + if peer_mut + .tunnel + .set_static_private( + private_key.clone(), + public_key, + Some(Arc::clone(&rate_limiter)), + ) + .is_err() + { + // In case we encounter an error, we will remove that peer + // An error will be a result of bad public key/secret key combination + bad_peers.push(Arc::clone(peer)); + } + } + + self.key_pair = key_pair; + self.rate_limiter = Some(rate_limiter); + + // Remove all the bad peers + for _ in bad_peers { + unimplemented!(); + } + } + + #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] + fn set_fwmark(&mut self, mark: u32) -> Result<(), Error> { + self.fwmark = Some(mark); + + // First set fwmark on listeners + if let Some(ref sock) = self.udp4 { + sock.set_mark(mark)?; + } + + if let Some(ref sock) = self.udp6 { + sock.set_mark(mark)?; + } + + // Then on all currently connected sockets + for peer in self.peers.values() { + if let Some(ref sock) = peer.lock().endpoint().conn { + sock.set_mark(mark)? + } + } + + Ok(()) + } + + fn clear_peers(&mut self) { + self.peers.clear(); + self.peers_by_idx.clear(); + self.peers_by_ip.clear(); + } + + fn register_notifiers(&mut self) -> Result<(), Error> { + let yield_ev = self + .queue + // The notification event handler simply returns Action::Yield + .new_notifier(Box::new(|_, _| Action::Yield))?; + self.yield_notice = Some(yield_ev); + + let exit_ev = self + .queue + // The exit event handler simply returns Action::Exit + .new_notifier(Box::new(|_, _| Action::Exit))?; + self.exit_notice = Some(exit_ev); + Ok(()) + } + + fn register_timers(&self) -> Result<(), Error> { + self.queue.new_periodic_event( + // Reset the rate limiter every second give or take + Box::new(|d, _| { + if let Some(r) = d.rate_limiter.as_ref() { + r.reset_count() + } + Action::Continue + }), + std::time::Duration::from_secs(1), + )?; + + self.queue.new_periodic_event( + // Execute the timed function of every peer in the list + Box::new(|d, t| { + let peer_map = &d.peers; + + let (udp4, udp6) = match (d.udp4.as_ref(), d.udp6.as_ref()) { + (Some(udp4), Some(udp6)) => (udp4, udp6), + _ => return Action::Continue, + }; + + // Go over each peer and invoke the timer function + for peer in peer_map.values() { + let mut p = peer.lock(); + let endpoint_addr = match p.endpoint().addr { + Some(addr) => addr, + None => continue, + }; + + match p.update_timers(&mut t.dst_buf[..]) { + TunnResult::Done => {} + TunnResult::Err(WireGuardError::ConnectionExpired) => { + p.shutdown_endpoint(); // close open udp socket + } + TunnResult::Err(e) => tracing::error!(message = "Timer error", error = ?e), + TunnResult::WriteToNetwork(packet) => { + match endpoint_addr { + SocketAddr::V4(_) => { + udp4.send_to(packet, &endpoint_addr.into()).ok() + } + SocketAddr::V6(_) => { + udp6.send_to(packet, &endpoint_addr.into()).ok() + } + }; + } + _ => panic!("Unexpected result from update_timers"), + }; + } + Action::Continue + }), + std::time::Duration::from_millis(250), + )?; + Ok(()) + } + + pub(crate) fn trigger_yield(&self) { + self.queue + .trigger_notification(self.yield_notice.as_ref().unwrap()) + } + + pub(crate) fn trigger_exit(&self) { + self.queue + .trigger_notification(self.exit_notice.as_ref().unwrap()) + } + + pub(crate) fn cancel_yield(&self) { + self.queue + .stop_notification(self.yield_notice.as_ref().unwrap()) + } + + fn register_udp_handler(&self, udp: socket2::Socket) -> Result<(), Error> { + self.queue.new_event( + udp.as_raw_fd(), + Box::new(move |d, t| { + // Handler that handles anonymous packets over UDP + let mut iter = MAX_ITR; + let (private_key, public_key) = d.key_pair.as_ref().expect("Key not set"); + + let rate_limiter = d.rate_limiter.as_ref().unwrap(); + + // Loop while we have packets on the anonymous connection + + // Safety: the `recv_from` implementation promises not to write uninitialised + // bytes to the buffer, so this casting is safe. + let src_buf = + unsafe { &mut *(&mut t.src_buf[..] as *mut [u8] as *mut [MaybeUninit]) }; + while let Ok((packet_len, addr)) = udp.recv_from(src_buf) { + let packet = &t.src_buf[..packet_len]; + // The rate limiter initially checks mac1 and mac2, and optionally asks to send a cookie + let parsed_packet = match rate_limiter.verify_packet( + Some(addr.as_socket().unwrap().ip()), + packet, + &mut t.dst_buf, + ) { + Ok(packet) => packet, + Err(TunnResult::WriteToNetwork(cookie)) => { + let _: Result<_, _> = udp.send_to(cookie, &addr); + continue; + } + Err(_) => continue, + }; + + let peer = match &parsed_packet { + Packet::HandshakeInit(p) => { + parse_handshake_anon(private_key, public_key, p) + .ok() + .and_then(|hh| { + d.peers.get(&x25519::PublicKey::from(hh.peer_static_public)) + }) + } + Packet::HandshakeResponse(p) => d.peers_by_idx.get(&(p.receiver_idx >> 8)), + Packet::PacketCookieReply(p) => d.peers_by_idx.get(&(p.receiver_idx >> 8)), + Packet::PacketData(p) => d.peers_by_idx.get(&(p.receiver_idx >> 8)), + }; + + let peer = match peer { + None => continue, + Some(peer) => peer, + }; + + let mut p = peer.lock(); + + // We found a peer, use it to decapsulate the message+ + let mut flush = false; // Are there packets to send from the queue? + match p + .tunnel + .handle_verified_packet(parsed_packet, &mut t.dst_buf[..]) + { + TunnResult::Done => {} + TunnResult::Err(_) => continue, + TunnResult::WriteToNetwork(packet) => { + flush = true; + let _: Result<_, _> = udp.send_to(packet, &addr); + } + TunnResult::WriteToTunnelV4(packet, addr) => { + if p.is_allowed_ip(addr) { + t.iface.write4(packet); + } + } + TunnResult::WriteToTunnelV6(packet, addr) => { + if p.is_allowed_ip(addr) { + t.iface.write6(packet); + } + } + }; + + if flush { + // Flush pending queue + while let TunnResult::WriteToNetwork(packet) = + p.tunnel.decapsulate(None, &[], &mut t.dst_buf[..]) + { + let _: Result<_, _> = udp.send_to(packet, &addr); + } + } + + // This packet was OK, that means we want to create a connected socket for this peer + let addr = addr.as_socket().unwrap(); + let ip_addr = addr.ip(); + p.set_endpoint(addr); + if d.config.use_connected_socket { + if let Ok(sock) = p.connect_endpoint(d.listen_port, d.fwmark) { + d.register_conn_handler(Arc::clone(peer), sock, ip_addr) + .unwrap(); + } + } + + iter -= 1; + if iter == 0 { + break; + } + } + Action::Continue + }), + )?; + Ok(()) + } + + fn register_conn_handler( + &self, + peer: Arc>, + udp: socket2::Socket, + peer_addr: IpAddr, + ) -> Result<(), Error> { + self.queue.new_event( + udp.as_raw_fd(), + Box::new(move |_, t| { + // The conn_handler handles packet received from a connected UDP socket, associated + // with a known peer, this saves us the hustle of finding the right peer. If another + // peer gets the same ip, it will be ignored until the socket does not expire. + let iface = &t.iface; + let mut iter = MAX_ITR; + + // Safety: the `recv_from` implementation promises not to write uninitialised + // bytes to the buffer, so this casting is safe. + let src_buf = + unsafe { &mut *(&mut t.src_buf[..] as *mut [u8] as *mut [MaybeUninit]) }; + + while let Ok(read_bytes) = udp.recv(src_buf) { + let mut flush = false; + let mut p = peer.lock(); + match p.tunnel.decapsulate( + Some(peer_addr), + &t.src_buf[..read_bytes], + &mut t.dst_buf[..], + ) { + TunnResult::Done => {} + TunnResult::Err(e) => eprintln!("Decapsulate error {:?}", e), + TunnResult::WriteToNetwork(packet) => { + flush = true; + let _: Result<_, _> = udp.send(packet); + } + TunnResult::WriteToTunnelV4(packet, addr) => { + if p.is_allowed_ip(addr) { + iface.write4(packet); + } + } + TunnResult::WriteToTunnelV6(packet, addr) => { + if p.is_allowed_ip(addr) { + iface.write6(packet); + } + } + }; + + if flush { + // Flush pending queue + while let TunnResult::WriteToNetwork(packet) = + p.tunnel.decapsulate(None, &[], &mut t.dst_buf[..]) + { + let _: Result<_, _> = udp.send(packet); + } + } + + iter -= 1; + if iter == 0 { + break; + } + } + Action::Continue + }), + )?; + Ok(()) + } + + fn register_iface_handler(&self, iface: Arc) -> Result<(), Error> { + self.queue.new_event( + iface.as_raw_fd(), + Box::new(move |d, t| { + // The iface_handler handles packets received from the WireGuard virtual network + // interface. The flow is as follows: + // * Read a packet + // * Determine peer based on packet destination ip + // * Encapsulate the packet for the given peer + // * Send encapsulated packet to the peer's endpoint + let mtu = d.mtu.load(Ordering::Relaxed); + + let udp4 = d.udp4.as_ref().expect("Not connected"); + let udp6 = d.udp6.as_ref().expect("Not connected"); + + let peers = &d.peers_by_ip; + for _ in 0..MAX_ITR { + let src = match iface.read(&mut t.src_buf[..mtu]) { + Ok(src) => src, + Err(Error::IfaceRead(e)) => { + let ek = e.kind(); + if ek == io::ErrorKind::Interrupted || ek == io::ErrorKind::WouldBlock { + break; + } + eprintln!("Fatal read error on tun interface: {:?}", e); + return Action::Exit; + } + Err(e) => { + eprintln!("Unexpected error on tun interface: {:?}", e); + return Action::Exit; + } + }; + + let dst_addr = match Tunn::dst_address(src) { + Some(addr) => addr, + None => continue, + }; + + let mut peer = match peers.find(dst_addr) { + Some(peer) => peer.lock(), + None => continue, + }; + + match peer.tunnel.encapsulate(src, &mut t.dst_buf[..]) { + TunnResult::Done => {} + TunnResult::Err(e) => { + tracing::error!(message = "Encapsulate error", error = ?e) + } + TunnResult::WriteToNetwork(packet) => { + let mut endpoint = peer.endpoint_mut(); + if let Some(conn) = endpoint.conn.as_mut() { + // Prefer to send using the connected socket + let _: Result<_, _> = conn.write(packet); + } else if let Some(addr @ SocketAddr::V4(_)) = endpoint.addr { + let _: Result<_, _> = udp4.send_to(packet, &addr.into()); + } else if let Some(addr @ SocketAddr::V6(_)) = endpoint.addr { + let _: Result<_, _> = udp6.send_to(packet, &addr.into()); + } else { + tracing::error!("No endpoint"); + } + } + _ => panic!("Unexpected result from encapsulate"), + }; + } + Action::Continue + }), + )?; + Ok(()) + } +} + +/// A basic linear-feedback shift register implemented as xorshift, used to +/// distribute peer indexes across the 24-bit address space reserved for peer +/// identification. +/// The purpose is to obscure the total number of peers using the system and to +/// ensure it requires a non-trivial amount of processing power and/or samples +/// to guess other peers' indices. Anything more ambitious than this is wasted +/// with only 24 bits of space. +struct IndexLfsr { + initial: u32, + lfsr: u32, + mask: u32, +} + +impl IndexLfsr { + /// Generate a random 24-bit nonzero integer + fn random_index() -> u32 { + const LFSR_MAX: u32 = 0xffffff; // 24-bit seed + loop { + let i = OsRng.next_u32() & LFSR_MAX; + if i > 0 { + // LFSR seed must be non-zero + return i; + } + } + } + + /// Generate the next value in the pseudorandom sequence + fn next(&mut self) -> u32 { + // 24-bit polynomial for randomness. This is arbitrarily chosen to + // inject bitflips into the value. + const LFSR_POLY: u32 = 0xd80000; // 24-bit polynomial + let value = self.lfsr - 1; // lfsr will never have value of 0 + self.lfsr = (self.lfsr >> 1) ^ ((0u32.wrapping_sub(self.lfsr & 1u32)) & LFSR_POLY); + assert!(self.lfsr != self.initial, "Too many peers created"); + value ^ self.mask + } +} + +impl Default for IndexLfsr { + fn default() -> Self { + let seed = Self::random_index(); + IndexLfsr { + initial: seed, + lfsr: seed, + mask: Self::random_index(), + } + } +} diff --git a/burrow/src/boringtun/src/device/peer.rs b/burrow/src/boringtun/src/device/peer.rs new file mode 100755 index 0000000..d7f2c22 --- /dev/null +++ b/burrow/src/boringtun/src/device/peer.rs @@ -0,0 +1,170 @@ +// Copyright (c) 2019 Cloudflare, Inc. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +use parking_lot::RwLock; +use socket2::{Domain, Protocol, Type}; + +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, Shutdown, SocketAddr, SocketAddrV4, SocketAddrV6}; +use std::str::FromStr; + +use crate::device::{AllowedIps, Error}; +use crate::noise::{Tunn, TunnResult}; + +#[derive(Default, Debug)] +pub struct Endpoint { + pub addr: Option, + pub conn: Option, +} + +pub struct Peer { + /// The associated tunnel struct + pub(crate) tunnel: Tunn, + /// The index the tunnel uses + index: u32, + endpoint: RwLock, + allowed_ips: AllowedIps<()>, + preshared_key: Option<[u8; 32]>, +} + +#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Debug)] +pub struct AllowedIP { + pub addr: IpAddr, + pub cidr: u8, +} + +impl FromStr for AllowedIP { + type Err = String; + + fn from_str(s: &str) -> Result { + let ip: Vec<&str> = s.split('/').collect(); + if ip.len() != 2 { + return Err("Invalid IP format".to_owned()); + } + + let (addr, cidr) = (ip[0].parse::(), ip[1].parse::()); + match (addr, cidr) { + (Ok(addr @ IpAddr::V4(_)), Ok(cidr)) if cidr <= 32 => Ok(AllowedIP { addr, cidr }), + (Ok(addr @ IpAddr::V6(_)), Ok(cidr)) if cidr <= 128 => Ok(AllowedIP { addr, cidr }), + _ => Err("Invalid IP format".to_owned()), + } + } +} + +impl Peer { + pub fn new( + tunnel: Tunn, + index: u32, + endpoint: Option, + allowed_ips: &[AllowedIP], + preshared_key: Option<[u8; 32]>, + ) -> Peer { + Peer { + tunnel, + index, + endpoint: RwLock::new(Endpoint { + addr: endpoint, + conn: None, + }), + allowed_ips: allowed_ips.iter().map(|ip| (ip, ())).collect(), + preshared_key, + } + } + + pub fn update_timers<'a>(&mut self, dst: &'a mut [u8]) -> TunnResult<'a> { + self.tunnel.update_timers(dst) + } + + pub fn endpoint(&self) -> parking_lot::RwLockReadGuard<'_, Endpoint> { + self.endpoint.read() + } + + pub(crate) fn endpoint_mut(&self) -> parking_lot::RwLockWriteGuard<'_, Endpoint> { + self.endpoint.write() + } + + pub fn shutdown_endpoint(&self) { + if let Some(conn) = self.endpoint.write().conn.take() { + tracing::info!("Disconnecting from endpoint"); + conn.shutdown(Shutdown::Both).unwrap(); + } + } + + pub fn set_endpoint(&self, addr: SocketAddr) { + let mut endpoint = self.endpoint.write(); + if endpoint.addr != Some(addr) { + // We only need to update the endpoint if it differs from the current one + if let Some(conn) = endpoint.conn.take() { + conn.shutdown(Shutdown::Both).unwrap(); + } + + endpoint.addr = Some(addr); + } + } + + pub fn connect_endpoint( + &self, + port: u16, + fwmark: Option, + ) -> Result { + let mut endpoint = self.endpoint.write(); + + if endpoint.conn.is_some() { + return Err(Error::Connect("Connected".to_owned())); + } + + let addr = endpoint + .addr + .expect("Attempt to connect to undefined endpoint"); + + let udp_conn = + socket2::Socket::new(Domain::for_address(addr), Type::STREAM, Some(Protocol::UDP))?; + udp_conn.set_reuse_address(true)?; + let bind_addr = if addr.is_ipv4() { + SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, port).into() + } else { + SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, port, 0, 0).into() + }; + udp_conn.bind(&bind_addr)?; + udp_conn.connect(&addr.into())?; + udp_conn.set_nonblocking(true)?; + + #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] + if let Some(fwmark) = fwmark { + udp_conn.set_mark(fwmark)?; + } + + tracing::info!( + message="Connected endpoint", + port=port, + endpoint=?endpoint.addr.unwrap() + ); + + endpoint.conn = Some(udp_conn.try_clone().unwrap()); + + Ok(udp_conn) + } + + pub fn is_allowed_ip>(&self, addr: I) -> bool { + self.allowed_ips.find(addr.into()).is_some() + } + + pub fn allowed_ips(&self) -> impl Iterator + '_ { + self.allowed_ips.iter().map(|(_, ip, cidr)| (ip, cidr)) + } + + pub fn time_since_last_handshake(&self) -> Option { + self.tunnel.time_since_last_handshake() + } + + pub fn persistent_keepalive(&self) -> Option { + self.tunnel.persistent_keepalive() + } + + pub fn preshared_key(&self) -> Option<&[u8; 32]> { + self.preshared_key.as_ref() + } + + pub fn index(&self) -> u32 { + self.index + } +} diff --git a/burrow/src/boringtun/src/lib.rs b/burrow/src/boringtun/src/lib.rs new file mode 100755 index 0000000..6ab410d --- /dev/null +++ b/burrow/src/boringtun/src/lib.rs @@ -0,0 +1,27 @@ +// Copyright (c) 2019 Cloudflare, Inc. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +//! Simple implementation of the client-side of the WireGuard protocol. +//! +//! git clone https://github.com/cloudflare/boringtun.git + +#[cfg(feature = "device")] +pub mod device; + +#[cfg(feature = "ffi-bindings")] +pub mod ffi; +#[cfg(feature = "jni-bindings")] +pub mod jni; +pub mod noise; + +#[cfg(not(feature = "mock-instant"))] +pub(crate) mod sleepyinstant; + +pub(crate) mod serialization; + +/// Re-export of the x25519 types +pub mod x25519 { + pub use x25519_dalek::{ + EphemeralSecret, PublicKey, ReusableSecret, SharedSecret, StaticSecret, + }; +} diff --git a/burrow/src/boringtun/src/noise/errors.rs b/burrow/src/boringtun/src/noise/errors.rs new file mode 100755 index 0000000..10513ae --- /dev/null +++ b/burrow/src/boringtun/src/noise/errors.rs @@ -0,0 +1,23 @@ +// Copyright (c) 2019 Cloudflare, Inc. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +#[derive(Debug)] +pub enum WireGuardError { + DestinationBufferTooSmall, + IncorrectPacketLength, + UnexpectedPacket, + WrongPacketType, + WrongIndex, + WrongKey, + InvalidTai64nTimestamp, + WrongTai64nTimestamp, + InvalidMac, + InvalidAeadTag, + InvalidCounter, + DuplicateCounter, + InvalidPacket, + NoCurrentSession, + LockFailed, + ConnectionExpired, + UnderLoad, +} diff --git a/burrow/src/boringtun/src/noise/handshake.rs b/burrow/src/boringtun/src/noise/handshake.rs new file mode 100755 index 0000000..b7c9373 --- /dev/null +++ b/burrow/src/boringtun/src/noise/handshake.rs @@ -0,0 +1,941 @@ +// Copyright (c) 2019 Cloudflare, Inc. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +use super::{HandshakeInit, HandshakeResponse, PacketCookieReply}; +use crate::noise::errors::WireGuardError; +use crate::noise::session::Session; +#[cfg(not(feature = "mock-instant"))] +use crate::sleepyinstant::Instant; +use crate::x25519; +use aead::{Aead, Payload}; +use blake2::digest::{FixedOutput, KeyInit}; +use blake2::{Blake2s256, Blake2sMac, Digest}; +use chacha20poly1305::XChaCha20Poly1305; +use rand_core::OsRng; +use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305}; +use std::convert::TryInto; +use std::time::{Duration, SystemTime}; + +#[cfg(feature = "mock-instant")] +use mock_instant::Instant; + +pub(crate) const LABEL_MAC1: &[u8; 8] = b"mac1----"; +pub(crate) const LABEL_COOKIE: &[u8; 8] = b"cookie--"; +const KEY_LEN: usize = 32; +const TIMESTAMP_LEN: usize = 12; + +// initiator.chaining_key = HASH(CONSTRUCTION) +const INITIAL_CHAIN_KEY: [u8; KEY_LEN] = [ + 96, 226, 109, 174, 243, 39, 239, 192, 46, 195, 53, 226, 160, 37, 210, 208, 22, 235, 66, 6, 248, + 114, 119, 245, 45, 56, 209, 152, 139, 120, 205, 54, +]; + +// initiator.chaining_hash = HASH(initiator.chaining_key || IDENTIFIER) +const INITIAL_CHAIN_HASH: [u8; KEY_LEN] = [ + 34, 17, 179, 97, 8, 26, 197, 102, 105, 18, 67, 219, 69, 138, 213, 50, 45, 156, 108, 102, 34, + 147, 232, 183, 14, 225, 156, 101, 186, 7, 158, 243, +]; + +#[inline] +pub(crate) fn b2s_hash(data1: &[u8], data2: &[u8]) -> [u8; 32] { + let mut hash = Blake2s256::new(); + hash.update(data1); + hash.update(data2); + hash.finalize().into() +} + +#[inline] +/// RFC 2401 HMAC+Blake2s, not to be confused with *keyed* Blake2s +pub(crate) fn b2s_hmac(key: &[u8], data1: &[u8]) -> [u8; 32] { + use blake2::digest::Update; + type HmacBlake2s = hmac::SimpleHmac; + let mut hmac = HmacBlake2s::new_from_slice(key).unwrap(); + hmac.update(data1); + hmac.finalize_fixed().into() +} + +#[inline] +/// Like b2s_hmac, but chain data1 and data2 together +pub(crate) fn b2s_hmac2(key: &[u8], data1: &[u8], data2: &[u8]) -> [u8; 32] { + use blake2::digest::Update; + type HmacBlake2s = hmac::SimpleHmac; + let mut hmac = HmacBlake2s::new_from_slice(key).unwrap(); + hmac.update(data1); + hmac.update(data2); + hmac.finalize_fixed().into() +} + +#[inline] +pub(crate) fn b2s_keyed_mac_16(key: &[u8], data1: &[u8]) -> [u8; 16] { + let mut hmac = Blake2sMac::new_from_slice(key).unwrap(); + blake2::digest::Update::update(&mut hmac, data1); + hmac.finalize_fixed().into() +} + +#[inline] +pub(crate) fn b2s_keyed_mac_16_2(key: &[u8], data1: &[u8], data2: &[u8]) -> [u8; 16] { + let mut hmac = Blake2sMac::new_from_slice(key).unwrap(); + blake2::digest::Update::update(&mut hmac, data1); + blake2::digest::Update::update(&mut hmac, data2); + hmac.finalize_fixed().into() +} + +pub(crate) fn b2s_mac_24(key: &[u8], data1: &[u8]) -> [u8; 24] { + let mut hmac = Blake2sMac::new_from_slice(key).unwrap(); + blake2::digest::Update::update(&mut hmac, data1); + hmac.finalize_fixed().into() +} + +#[inline] +/// This wrapper involves an extra copy and MAY BE SLOWER +fn aead_chacha20_seal(ciphertext: &mut [u8], key: &[u8], counter: u64, data: &[u8], aad: &[u8]) { + let mut nonce: [u8; 12] = [0; 12]; + nonce[4..12].copy_from_slice(&counter.to_le_bytes()); + + aead_chacha20_seal_inner(ciphertext, key, nonce, data, aad) +} + +#[inline] +fn aead_chacha20_seal_inner( + ciphertext: &mut [u8], + key: &[u8], + nonce: [u8; 12], + data: &[u8], + aad: &[u8], +) { + let key = LessSafeKey::new(UnboundKey::new(&CHACHA20_POLY1305, key).unwrap()); + + ciphertext[..data.len()].copy_from_slice(data); + + let tag = key + .seal_in_place_separate_tag( + Nonce::assume_unique_for_key(nonce), + Aad::from(aad), + &mut ciphertext[..data.len()], + ) + .unwrap(); + + ciphertext[data.len()..].copy_from_slice(tag.as_ref()); +} + +#[inline] +/// This wrapper involves an extra copy and MAY BE SLOWER +fn aead_chacha20_open( + buffer: &mut [u8], + key: &[u8], + counter: u64, + data: &[u8], + aad: &[u8], +) -> Result<(), WireGuardError> { + let mut nonce: [u8; 12] = [0; 12]; + nonce[4..].copy_from_slice(&counter.to_le_bytes()); + + aead_chacha20_open_inner(buffer, key, nonce, data, aad) + .map_err(|_| WireGuardError::InvalidAeadTag)?; + Ok(()) +} + +#[inline] +fn aead_chacha20_open_inner( + buffer: &mut [u8], + key: &[u8], + nonce: [u8; 12], + data: &[u8], + aad: &[u8], +) -> Result<(), ring::error::Unspecified> { + let key = LessSafeKey::new(UnboundKey::new(&CHACHA20_POLY1305, key).unwrap()); + + let mut inner_buffer = data.to_owned(); + + let plaintext = key.open_in_place( + Nonce::assume_unique_for_key(nonce), + Aad::from(aad), + &mut inner_buffer, + )?; + + buffer.copy_from_slice(plaintext); + + Ok(()) +} + +#[derive(Debug)] +/// This struct represents a 12 byte [Tai64N](https://cr.yp.to/libtai/tai64.html) timestamp +struct Tai64N { + secs: u64, + nano: u32, +} + +#[derive(Debug)] +/// This struct computes a [Tai64N](https://cr.yp.to/libtai/tai64.html) timestamp from current system time +struct TimeStamper { + duration_at_start: Duration, + instant_at_start: Instant, +} + +impl TimeStamper { + /// Create a new TimeStamper + pub fn new() -> TimeStamper { + TimeStamper { + duration_at_start: SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .unwrap(), + instant_at_start: Instant::now(), + } + } + + /// Take time reading and generate a 12 byte timestamp + pub fn stamp(&self) -> [u8; 12] { + const TAI64_BASE: u64 = (1u64 << 62) + 37; + let mut ext_stamp = [0u8; 12]; + let stamp = Instant::now().duration_since(self.instant_at_start) + self.duration_at_start; + ext_stamp[0..8].copy_from_slice(&(stamp.as_secs() + TAI64_BASE).to_be_bytes()); + ext_stamp[8..12].copy_from_slice(&stamp.subsec_nanos().to_be_bytes()); + ext_stamp + } +} + +impl Tai64N { + /// A zeroed out timestamp + fn zero() -> Tai64N { + Tai64N { secs: 0, nano: 0 } + } + + /// Parse a timestamp from a 12 byte u8 slice + fn parse(buf: &[u8; 12]) -> Result { + if buf.len() < 12 { + return Err(WireGuardError::InvalidTai64nTimestamp); + } + + let (sec_bytes, nano_bytes) = buf.split_at(std::mem::size_of::()); + let secs = u64::from_be_bytes(sec_bytes.try_into().unwrap()); + let nano = u32::from_be_bytes(nano_bytes.try_into().unwrap()); + + // WireGuard does not actually expect tai64n timestamp, just monotonically increasing one + //if secs < (1u64 << 62) || secs >= (1u64 << 63) { + // return Err(WireGuardError::InvalidTai64nTimestamp); + //}; + //if nano >= 1_000_000_000 { + // return Err(WireGuardError::InvalidTai64nTimestamp); + //} + + Ok(Tai64N { secs, nano }) + } + + /// Check if this timestamp represents a time that is chronologically after the time represented + /// by the other timestamp + pub fn after(&self, other: &Tai64N) -> bool { + (self.secs > other.secs) || ((self.secs == other.secs) && (self.nano > other.nano)) + } +} + +/// Parameters used by the noise protocol +struct NoiseParams { + /// Our static public key + static_public: x25519::PublicKey, + /// Our static private key + static_private: x25519::StaticSecret, + /// Static public key of the other party + peer_static_public: x25519::PublicKey, + /// A shared key = DH(static_private, peer_static_public) + static_shared: x25519::SharedSecret, + /// A pre-computation of HASH("mac1----", peer_static_public) for this peer + sending_mac1_key: [u8; KEY_LEN], + /// An optional preshared key + preshared_key: Option<[u8; KEY_LEN]>, +} + +impl std::fmt::Debug for NoiseParams { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("NoiseParams") + .field("static_public", &self.static_public) + .field("static_private", &"") + .field("peer_static_public", &self.peer_static_public) + .field("static_shared", &"") + .field("sending_mac1_key", &self.sending_mac1_key) + .field("preshared_key", &self.preshared_key) + .finish() + } +} + +struct HandshakeInitSentState { + local_index: u32, + hash: [u8; KEY_LEN], + chaining_key: [u8; KEY_LEN], + ephemeral_private: x25519::ReusableSecret, + time_sent: Instant, +} + +impl std::fmt::Debug for HandshakeInitSentState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("HandshakeInitSentState") + .field("local_index", &self.local_index) + .field("hash", &self.hash) + .field("chaining_key", &self.chaining_key) + .field("ephemeral_private", &"") + .field("time_sent", &self.time_sent) + .finish() + } +} + +#[derive(Debug)] +enum HandshakeState { + /// No handshake in process + None, + /// We initiated the handshake + InitSent(HandshakeInitSentState), + /// Handshake initiated by peer + InitReceived { + hash: [u8; KEY_LEN], + chaining_key: [u8; KEY_LEN], + peer_ephemeral_public: x25519::PublicKey, + peer_index: u32, + }, + /// Handshake was established too long ago (implies no handshake is in progress) + Expired, +} + +pub struct Handshake { + params: NoiseParams, + /// Index of the next session + next_index: u32, + /// Allow to have two outgoing handshakes in flight, because sometimes we may receive a delayed response to a handshake with bad networks + previous: HandshakeState, + /// Current handshake state + state: HandshakeState, + cookies: Cookies, + /// The timestamp of the last handshake we received + last_handshake_timestamp: Tai64N, + // TODO: make TimeStamper a singleton + stamper: TimeStamper, + pub(super) last_rtt: Option, +} + +#[derive(Default)] +struct Cookies { + last_mac1: Option<[u8; 16]>, + index: u32, + write_cookie: Option<[u8; 16]>, +} + +#[derive(Debug)] +pub struct HalfHandshake { + pub peer_index: u32, + pub peer_static_public: [u8; 32], +} + +pub fn parse_handshake_anon( + static_private: &x25519::StaticSecret, + static_public: &x25519::PublicKey, + packet: &HandshakeInit, +) -> Result { + let peer_index = packet.sender_idx; + // initiator.chaining_key = HASH(CONSTRUCTION) + let mut chaining_key = INITIAL_CHAIN_KEY; + // initiator.hash = HASH(HASH(initiator.chaining_key || IDENTIFIER) || responder.static_public) + let mut hash = INITIAL_CHAIN_HASH; + hash = b2s_hash(&hash, static_public.as_bytes()); + // msg.unencrypted_ephemeral = DH_PUBKEY(initiator.ephemeral_private) + let peer_ephemeral_public = x25519::PublicKey::from(*packet.unencrypted_ephemeral); + // initiator.hash = HASH(initiator.hash || msg.unencrypted_ephemeral) + hash = b2s_hash(&hash, peer_ephemeral_public.as_bytes()); + // temp = HMAC(initiator.chaining_key, msg.unencrypted_ephemeral) + // initiator.chaining_key = HMAC(temp, 0x1) + chaining_key = b2s_hmac( + &b2s_hmac(&chaining_key, peer_ephemeral_public.as_bytes()), + &[0x01], + ); + // temp = HMAC(initiator.chaining_key, DH(initiator.ephemeral_private, responder.static_public)) + let ephemeral_shared = static_private.diffie_hellman(&peer_ephemeral_public); + let temp = b2s_hmac(&chaining_key, &ephemeral_shared.to_bytes()); + // initiator.chaining_key = HMAC(temp, 0x1) + chaining_key = b2s_hmac(&temp, &[0x01]); + // key = HMAC(temp, initiator.chaining_key || 0x2) + let key = b2s_hmac2(&temp, &chaining_key, &[0x02]); + + let mut peer_static_public = [0u8; KEY_LEN]; + // msg.encrypted_static = AEAD(key, 0, initiator.static_public, initiator.hash) + aead_chacha20_open( + &mut peer_static_public, + &key, + 0, + packet.encrypted_static, + &hash, + )?; + + Ok(HalfHandshake { + peer_index, + peer_static_public, + }) +} + +impl NoiseParams { + /// New noise params struct from our secret key, peers public key, and optional preshared key + fn new( + static_private: x25519::StaticSecret, + static_public: x25519::PublicKey, + peer_static_public: x25519::PublicKey, + preshared_key: Option<[u8; 32]>, + ) -> Result { + let static_shared = static_private.diffie_hellman(&peer_static_public); + + let initial_sending_mac_key = b2s_hash(LABEL_MAC1, peer_static_public.as_bytes()); + + Ok(NoiseParams { + static_public, + static_private, + peer_static_public, + static_shared, + sending_mac1_key: initial_sending_mac_key, + preshared_key, + }) + } + + /// Set a new private key + fn set_static_private( + &mut self, + static_private: x25519::StaticSecret, + static_public: x25519::PublicKey, + ) -> Result<(), WireGuardError> { + // Check that the public key indeed matches the private key + let check_key = x25519::PublicKey::from(&static_private); + assert_eq!(check_key.as_bytes(), static_public.as_bytes()); + + self.static_private = static_private; + self.static_public = static_public; + + self.static_shared = self.static_private.diffie_hellman(&self.peer_static_public); + Ok(()) + } +} + +impl Handshake { + pub(crate) fn new( + static_private: x25519::StaticSecret, + static_public: x25519::PublicKey, + peer_static_public: x25519::PublicKey, + global_idx: u32, + preshared_key: Option<[u8; 32]>, + ) -> Result { + let params = NoiseParams::new( + static_private, + static_public, + peer_static_public, + preshared_key, + )?; + + Ok(Handshake { + params, + next_index: global_idx, + previous: HandshakeState::None, + state: HandshakeState::None, + last_handshake_timestamp: Tai64N::zero(), + stamper: TimeStamper::new(), + cookies: Default::default(), + last_rtt: None, + }) + } + + pub(crate) fn is_in_progress(&self) -> bool { + !matches!(self.state, HandshakeState::None | HandshakeState::Expired) + } + + pub(crate) fn timer(&self) -> Option { + match self.state { + HandshakeState::InitSent(HandshakeInitSentState { time_sent, .. }) => Some(time_sent), + _ => None, + } + } + + pub(crate) fn set_expired(&mut self) { + self.previous = HandshakeState::Expired; + self.state = HandshakeState::Expired; + } + + pub(crate) fn is_expired(&self) -> bool { + matches!(self.state, HandshakeState::Expired) + } + + pub(crate) fn has_cookie(&self) -> bool { + self.cookies.write_cookie.is_some() + } + + pub(crate) fn clear_cookie(&mut self) { + self.cookies.write_cookie = None; + } + + // The index used is 24 bits for peer index, allowing for 16M active peers per server and 8 bits for cyclic session index + fn inc_index(&mut self) -> u32 { + let index = self.next_index; + let idx8 = index as u8; + self.next_index = (index & !0xff) | u32::from(idx8.wrapping_add(1)); + self.next_index + } + + pub(crate) fn set_static_private( + &mut self, + private_key: x25519::StaticSecret, + public_key: x25519::PublicKey, + ) -> Result<(), WireGuardError> { + self.params.set_static_private(private_key, public_key) + } + + pub(super) fn receive_handshake_initialization<'a>( + &mut self, + packet: HandshakeInit, + dst: &'a mut [u8], + ) -> Result<(&'a mut [u8], Session), WireGuardError> { + // initiator.chaining_key = HASH(CONSTRUCTION) + let mut chaining_key = INITIAL_CHAIN_KEY; + // initiator.hash = HASH(HASH(initiator.chaining_key || IDENTIFIER) || responder.static_public) + let mut hash = INITIAL_CHAIN_HASH; + hash = b2s_hash(&hash, self.params.static_public.as_bytes()); + // msg.sender_index = little_endian(initiator.sender_index) + let peer_index = packet.sender_idx; + // msg.unencrypted_ephemeral = DH_PUBKEY(initiator.ephemeral_private) + let peer_ephemeral_public = x25519::PublicKey::from(*packet.unencrypted_ephemeral); + // initiator.hash = HASH(initiator.hash || msg.unencrypted_ephemeral) + hash = b2s_hash(&hash, peer_ephemeral_public.as_bytes()); + // temp = HMAC(initiator.chaining_key, msg.unencrypted_ephemeral) + // initiator.chaining_key = HMAC(temp, 0x1) + chaining_key = b2s_hmac( + &b2s_hmac(&chaining_key, peer_ephemeral_public.as_bytes()), + &[0x01], + ); + // temp = HMAC(initiator.chaining_key, DH(initiator.ephemeral_private, responder.static_public)) + let ephemeral_shared = self + .params + .static_private + .diffie_hellman(&peer_ephemeral_public); + let temp = b2s_hmac(&chaining_key, &ephemeral_shared.to_bytes()); + // initiator.chaining_key = HMAC(temp, 0x1) + chaining_key = b2s_hmac(&temp, &[0x01]); + // key = HMAC(temp, initiator.chaining_key || 0x2) + let key = b2s_hmac2(&temp, &chaining_key, &[0x02]); + + let mut peer_static_public_decrypted = [0u8; KEY_LEN]; + // msg.encrypted_static = AEAD(key, 0, initiator.static_public, initiator.hash) + aead_chacha20_open( + &mut peer_static_public_decrypted, + &key, + 0, + packet.encrypted_static, + &hash, + )?; + + ring::constant_time::verify_slices_are_equal( + self.params.peer_static_public.as_bytes(), + &peer_static_public_decrypted, + ) + .map_err(|_| WireGuardError::WrongKey)?; + + // initiator.hash = HASH(initiator.hash || msg.encrypted_static) + hash = b2s_hash(&hash, packet.encrypted_static); + // temp = HMAC(initiator.chaining_key, DH(initiator.static_private, responder.static_public)) + let temp = b2s_hmac(&chaining_key, self.params.static_shared.as_bytes()); + // initiator.chaining_key = HMAC(temp, 0x1) + chaining_key = b2s_hmac(&temp, &[0x01]); + // key = HMAC(temp, initiator.chaining_key || 0x2) + let key = b2s_hmac2(&temp, &chaining_key, &[0x02]); + // msg.encrypted_timestamp = AEAD(key, 0, TAI64N(), initiator.hash) + let mut timestamp = [0u8; TIMESTAMP_LEN]; + aead_chacha20_open(&mut timestamp, &key, 0, packet.encrypted_timestamp, &hash)?; + + let timestamp = Tai64N::parse(×tamp)?; + if !timestamp.after(&self.last_handshake_timestamp) { + // Possibly a replay + return Err(WireGuardError::WrongTai64nTimestamp); + } + self.last_handshake_timestamp = timestamp; + + // initiator.hash = HASH(initiator.hash || msg.encrypted_timestamp) + hash = b2s_hash(&hash, packet.encrypted_timestamp); + + self.previous = std::mem::replace( + &mut self.state, + HandshakeState::InitReceived { + chaining_key, + hash, + peer_ephemeral_public, + peer_index, + }, + ); + + self.format_handshake_response(dst) + } + + pub(super) fn receive_handshake_response( + &mut self, + packet: HandshakeResponse, + ) -> Result { + // Check if there is a handshake awaiting a response and return the correct one + let (state, is_previous) = match (&self.state, &self.previous) { + (HandshakeState::InitSent(s), _) if s.local_index == packet.receiver_idx => (s, false), + (_, HandshakeState::InitSent(s)) if s.local_index == packet.receiver_idx => (s, true), + _ => return Err(WireGuardError::UnexpectedPacket), + }; + + let peer_index = packet.sender_idx; + let local_index = state.local_index; + + let unencrypted_ephemeral = x25519::PublicKey::from(*packet.unencrypted_ephemeral); + // msg.unencrypted_ephemeral = DH_PUBKEY(responder.ephemeral_private) + // responder.hash = HASH(responder.hash || msg.unencrypted_ephemeral) + let mut hash = b2s_hash(&state.hash, unencrypted_ephemeral.as_bytes()); + // temp = HMAC(responder.chaining_key, msg.unencrypted_ephemeral) + let temp = b2s_hmac(&state.chaining_key, unencrypted_ephemeral.as_bytes()); + // responder.chaining_key = HMAC(temp, 0x1) + let mut chaining_key = b2s_hmac(&temp, &[0x01]); + // temp = HMAC(responder.chaining_key, DH(responder.ephemeral_private, initiator.ephemeral_public)) + let ephemeral_shared = state + .ephemeral_private + .diffie_hellman(&unencrypted_ephemeral); + let temp = b2s_hmac(&chaining_key, &ephemeral_shared.to_bytes()); + // responder.chaining_key = HMAC(temp, 0x1) + chaining_key = b2s_hmac(&temp, &[0x01]); + // temp = HMAC(responder.chaining_key, DH(responder.ephemeral_private, initiator.static_public)) + let temp = b2s_hmac( + &chaining_key, + &self + .params + .static_private + .diffie_hellman(&unencrypted_ephemeral) + .to_bytes(), + ); + // responder.chaining_key = HMAC(temp, 0x1) + chaining_key = b2s_hmac(&temp, &[0x01]); + // temp = HMAC(responder.chaining_key, preshared_key) + let temp = b2s_hmac( + &chaining_key, + &self.params.preshared_key.unwrap_or([0u8; 32])[..], + ); + // responder.chaining_key = HMAC(temp, 0x1) + chaining_key = b2s_hmac(&temp, &[0x01]); + // temp2 = HMAC(temp, responder.chaining_key || 0x2) + let temp2 = b2s_hmac2(&temp, &chaining_key, &[0x02]); + // key = HMAC(temp, temp2 || 0x3) + let key = b2s_hmac2(&temp, &temp2, &[0x03]); + // responder.hash = HASH(responder.hash || temp2) + hash = b2s_hash(&hash, &temp2); + // msg.encrypted_nothing = AEAD(key, 0, [empty], responder.hash) + aead_chacha20_open(&mut [], &key, 0, packet.encrypted_nothing, &hash)?; + + // responder.hash = HASH(responder.hash || msg.encrypted_nothing) + // hash = b2s_hash(hash, buf[ENC_NOTHING_OFF..ENC_NOTHING_OFF + ENC_NOTHING_SZ]); + + // Derive keys + // temp1 = HMAC(initiator.chaining_key, [empty]) + // temp2 = HMAC(temp1, 0x1) + // temp3 = HMAC(temp1, temp2 || 0x2) + // initiator.sending_key = temp2 + // initiator.receiving_key = temp3 + // initiator.sending_key_counter = 0 + // initiator.receiving_key_counter = 0 + let temp1 = b2s_hmac(&chaining_key, &[]); + let temp2 = b2s_hmac(&temp1, &[0x01]); + let temp3 = b2s_hmac2(&temp1, &temp2, &[0x02]); + + let rtt_time = Instant::now().duration_since(state.time_sent); + self.last_rtt = Some(rtt_time.as_millis() as u32); + + if is_previous { + self.previous = HandshakeState::None; + } else { + self.state = HandshakeState::None; + } + Ok(Session::new(local_index, peer_index, temp3, temp2)) + } + + pub(super) fn receive_cookie_reply( + &mut self, + packet: PacketCookieReply, + ) -> Result<(), WireGuardError> { + let mac1 = match self.cookies.last_mac1 { + Some(mac) => mac, + None => { + return Err(WireGuardError::UnexpectedPacket); + } + }; + + let local_index = self.cookies.index; + if packet.receiver_idx != local_index { + return Err(WireGuardError::WrongIndex); + } + // msg.encrypted_cookie = XAEAD(HASH(LABEL_COOKIE || responder.static_public), msg.nonce, cookie, last_received_msg.mac1) + let key = b2s_hash(LABEL_COOKIE, self.params.peer_static_public.as_bytes()); // TODO: pre-compute + + let payload = Payload { + aad: &mac1[0..16], + msg: packet.encrypted_cookie, + }; + let plaintext = XChaCha20Poly1305::new_from_slice(&key) + .unwrap() + .decrypt(packet.nonce.into(), payload) + .map_err(|_| WireGuardError::InvalidAeadTag)?; + + let cookie = plaintext + .try_into() + .map_err(|_| WireGuardError::InvalidPacket)?; + self.cookies.write_cookie = Some(cookie); + Ok(()) + } + + // Compute and append mac1 and mac2 to a handshake message + fn append_mac1_and_mac2<'a>( + &mut self, + local_index: u32, + dst: &'a mut [u8], + ) -> Result<&'a mut [u8], WireGuardError> { + let mac1_off = dst.len() - 32; + let mac2_off = dst.len() - 16; + + // msg.mac1 = MAC(HASH(LABEL_MAC1 || responder.static_public), msg[0:offsetof(msg.mac1)]) + let msg_mac1 = b2s_keyed_mac_16(&self.params.sending_mac1_key, &dst[..mac1_off]); + + dst[mac1_off..mac2_off].copy_from_slice(&msg_mac1[..]); + + //msg.mac2 = MAC(initiator.last_received_cookie, msg[0:offsetof(msg.mac2)]) + let msg_mac2: [u8; 16] = if let Some(cookie) = self.cookies.write_cookie { + b2s_keyed_mac_16(&cookie, &dst[..mac2_off]) + } else { + [0u8; 16] + }; + + dst[mac2_off..].copy_from_slice(&msg_mac2[..]); + + self.cookies.index = local_index; + self.cookies.last_mac1 = Some(msg_mac1); + Ok(dst) + } + + pub(super) fn format_handshake_initiation<'a>( + &mut self, + dst: &'a mut [u8], + ) -> Result<&'a mut [u8], WireGuardError> { + if dst.len() < super::HANDSHAKE_INIT_SZ { + return Err(WireGuardError::DestinationBufferTooSmall); + } + + let (message_type, rest) = dst.split_at_mut(4); + let (sender_index, rest) = rest.split_at_mut(4); + let (unencrypted_ephemeral, rest) = rest.split_at_mut(32); + let (encrypted_static, rest) = rest.split_at_mut(32 + 16); + let (encrypted_timestamp, _) = rest.split_at_mut(12 + 16); + + let local_index = self.inc_index(); + + // initiator.chaining_key = HASH(CONSTRUCTION) + let mut chaining_key = INITIAL_CHAIN_KEY; + // initiator.hash = HASH(HASH(initiator.chaining_key || IDENTIFIER) || responder.static_public) + let mut hash = INITIAL_CHAIN_HASH; + hash = b2s_hash(&hash, self.params.peer_static_public.as_bytes()); + // initiator.ephemeral_private = DH_GENERATE() + let ephemeral_private = x25519::ReusableSecret::random_from_rng(OsRng); + // msg.message_type = 1 + // msg.reserved_zero = { 0, 0, 0 } + message_type.copy_from_slice(&super::HANDSHAKE_INIT.to_le_bytes()); + // msg.sender_index = little_endian(initiator.sender_index) + sender_index.copy_from_slice(&local_index.to_le_bytes()); + // msg.unencrypted_ephemeral = DH_PUBKEY(initiator.ephemeral_private) + unencrypted_ephemeral + .copy_from_slice(x25519::PublicKey::from(&ephemeral_private).as_bytes()); + // initiator.hash = HASH(initiator.hash || msg.unencrypted_ephemeral) + hash = b2s_hash(&hash, unencrypted_ephemeral); + // temp = HMAC(initiator.chaining_key, msg.unencrypted_ephemeral) + // initiator.chaining_key = HMAC(temp, 0x1) + chaining_key = b2s_hmac(&b2s_hmac(&chaining_key, unencrypted_ephemeral), &[0x01]); + // temp = HMAC(initiator.chaining_key, DH(initiator.ephemeral_private, responder.static_public)) + let ephemeral_shared = ephemeral_private.diffie_hellman(&self.params.peer_static_public); + let temp = b2s_hmac(&chaining_key, &ephemeral_shared.to_bytes()); + // initiator.chaining_key = HMAC(temp, 0x1) + chaining_key = b2s_hmac(&temp, &[0x01]); + // key = HMAC(temp, initiator.chaining_key || 0x2) + let key = b2s_hmac2(&temp, &chaining_key, &[0x02]); + // msg.encrypted_static = AEAD(key, 0, initiator.static_public, initiator.hash) + aead_chacha20_seal( + encrypted_static, + &key, + 0, + self.params.static_public.as_bytes(), + &hash, + ); + // initiator.hash = HASH(initiator.hash || msg.encrypted_static) + hash = b2s_hash(&hash, encrypted_static); + // temp = HMAC(initiator.chaining_key, DH(initiator.static_private, responder.static_public)) + let temp = b2s_hmac(&chaining_key, self.params.static_shared.as_bytes()); + // initiator.chaining_key = HMAC(temp, 0x1) + chaining_key = b2s_hmac(&temp, &[0x01]); + // key = HMAC(temp, initiator.chaining_key || 0x2) + let key = b2s_hmac2(&temp, &chaining_key, &[0x02]); + // msg.encrypted_timestamp = AEAD(key, 0, TAI64N(), initiator.hash) + let timestamp = self.stamper.stamp(); + aead_chacha20_seal(encrypted_timestamp, &key, 0, ×tamp, &hash); + // initiator.hash = HASH(initiator.hash || msg.encrypted_timestamp) + hash = b2s_hash(&hash, encrypted_timestamp); + + let time_now = Instant::now(); + self.previous = std::mem::replace( + &mut self.state, + HandshakeState::InitSent(HandshakeInitSentState { + local_index, + chaining_key, + hash, + ephemeral_private, + time_sent: time_now, + }), + ); + + self.append_mac1_and_mac2(local_index, &mut dst[..super::HANDSHAKE_INIT_SZ]) + } + + fn format_handshake_response<'a>( + &mut self, + dst: &'a mut [u8], + ) -> Result<(&'a mut [u8], Session), WireGuardError> { + if dst.len() < super::HANDSHAKE_RESP_SZ { + return Err(WireGuardError::DestinationBufferTooSmall); + } + + let state = std::mem::replace(&mut self.state, HandshakeState::None); + let (mut chaining_key, mut hash, peer_ephemeral_public, peer_index) = match state { + HandshakeState::InitReceived { + chaining_key, + hash, + peer_ephemeral_public, + peer_index, + } => (chaining_key, hash, peer_ephemeral_public, peer_index), + _ => { + panic!("Unexpected attempt to call send_handshake_response"); + } + }; + + let (message_type, rest) = dst.split_at_mut(4); + let (sender_index, rest) = rest.split_at_mut(4); + let (receiver_index, rest) = rest.split_at_mut(4); + let (unencrypted_ephemeral, rest) = rest.split_at_mut(32); + let (encrypted_nothing, _) = rest.split_at_mut(16); + + // responder.ephemeral_private = DH_GENERATE() + let ephemeral_private = x25519::ReusableSecret::random_from_rng(OsRng); + let local_index = self.inc_index(); + // msg.message_type = 2 + // msg.reserved_zero = { 0, 0, 0 } + message_type.copy_from_slice(&super::HANDSHAKE_RESP.to_le_bytes()); + // msg.sender_index = little_endian(responder.sender_index) + sender_index.copy_from_slice(&local_index.to_le_bytes()); + // msg.receiver_index = little_endian(initiator.sender_index) + receiver_index.copy_from_slice(&peer_index.to_le_bytes()); + // msg.unencrypted_ephemeral = DH_PUBKEY(initiator.ephemeral_private) + unencrypted_ephemeral + .copy_from_slice(x25519::PublicKey::from(&ephemeral_private).as_bytes()); + // responder.hash = HASH(responder.hash || msg.unencrypted_ephemeral) + hash = b2s_hash(&hash, unencrypted_ephemeral); + // temp = HMAC(responder.chaining_key, msg.unencrypted_ephemeral) + let temp = b2s_hmac(&chaining_key, unencrypted_ephemeral); + // responder.chaining_key = HMAC(temp, 0x1) + chaining_key = b2s_hmac(&temp, &[0x01]); + // temp = HMAC(responder.chaining_key, DH(responder.ephemeral_private, initiator.ephemeral_public)) + let ephemeral_shared = ephemeral_private.diffie_hellman(&peer_ephemeral_public); + let temp = b2s_hmac(&chaining_key, &ephemeral_shared.to_bytes()); + // responder.chaining_key = HMAC(temp, 0x1) + chaining_key = b2s_hmac(&temp, &[0x01]); + // temp = HMAC(responder.chaining_key, DH(responder.ephemeral_private, initiator.static_public)) + let temp = b2s_hmac( + &chaining_key, + &ephemeral_private + .diffie_hellman(&self.params.peer_static_public) + .to_bytes(), + ); + // responder.chaining_key = HMAC(temp, 0x1) + chaining_key = b2s_hmac(&temp, &[0x01]); + // temp = HMAC(responder.chaining_key, preshared_key) + let temp = b2s_hmac( + &chaining_key, + &self.params.preshared_key.unwrap_or([0u8; 32])[..], + ); + // responder.chaining_key = HMAC(temp, 0x1) + chaining_key = b2s_hmac(&temp, &[0x01]); + // temp2 = HMAC(temp, responder.chaining_key || 0x2) + let temp2 = b2s_hmac2(&temp, &chaining_key, &[0x02]); + // key = HMAC(temp, temp2 || 0x3) + let key = b2s_hmac2(&temp, &temp2, &[0x03]); + // responder.hash = HASH(responder.hash || temp2) + hash = b2s_hash(&hash, &temp2); + // msg.encrypted_nothing = AEAD(key, 0, [empty], responder.hash) + aead_chacha20_seal(encrypted_nothing, &key, 0, &[], &hash); + + // Derive keys + // temp1 = HMAC(initiator.chaining_key, [empty]) + // temp2 = HMAC(temp1, 0x1) + // temp3 = HMAC(temp1, temp2 || 0x2) + // initiator.sending_key = temp2 + // initiator.receiving_key = temp3 + // initiator.sending_key_counter = 0 + // initiator.receiving_key_counter = 0 + let temp1 = b2s_hmac(&chaining_key, &[]); + let temp2 = b2s_hmac(&temp1, &[0x01]); + let temp3 = b2s_hmac2(&temp1, &temp2, &[0x02]); + + let dst = self.append_mac1_and_mac2(local_index, &mut dst[..super::HANDSHAKE_RESP_SZ])?; + + Ok((dst, Session::new(local_index, peer_index, temp2, temp3))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn chacha20_seal_rfc7530_test_vector() { + let plaintext = b"Ladies and Gentlemen of the class of '99: If I could offer you only one tip for the future, sunscreen would be it."; + let aad: [u8; 12] = [ + 0x50, 0x51, 0x52, 0x53, 0xc0, 0xc1, 0xc2, 0xc3, 0xc4, 0xc5, 0xc6, 0xc7, + ]; + let key: [u8; 32] = [ + 0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, + 0x8e, 0x8f, 0x90, 0x91, 0x92, 0x93, 0x94, 0x95, 0x96, 0x97, 0x98, 0x99, 0x9a, 0x9b, + 0x9c, 0x9d, 0x9e, 0x9f, + ]; + let nonce: [u8; 12] = [ + 0x07, 0x00, 0x00, 0x00, 0x40, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, + ]; + let mut buffer = vec![0; plaintext.len() + 16]; + + aead_chacha20_seal_inner(&mut buffer, &key, nonce, plaintext, &aad); + + const EXPECTED_CIPHERTEXT: [u8; 114] = [ + 0xd3, 0x1a, 0x8d, 0x34, 0x64, 0x8e, 0x60, 0xdb, 0x7b, 0x86, 0xaf, 0xbc, 0x53, 0xef, + 0x7e, 0xc2, 0xa4, 0xad, 0xed, 0x51, 0x29, 0x6e, 0x08, 0xfe, 0xa9, 0xe2, 0xb5, 0xa7, + 0x36, 0xee, 0x62, 0xd6, 0x3d, 0xbe, 0xa4, 0x5e, 0x8c, 0xa9, 0x67, 0x12, 0x82, 0xfa, + 0xfb, 0x69, 0xda, 0x92, 0x72, 0x8b, 0x1a, 0x71, 0xde, 0x0a, 0x9e, 0x06, 0x0b, 0x29, + 0x05, 0xd6, 0xa5, 0xb6, 0x7e, 0xcd, 0x3b, 0x36, 0x92, 0xdd, 0xbd, 0x7f, 0x2d, 0x77, + 0x8b, 0x8c, 0x98, 0x03, 0xae, 0xe3, 0x28, 0x09, 0x1b, 0x58, 0xfa, 0xb3, 0x24, 0xe4, + 0xfa, 0xd6, 0x75, 0x94, 0x55, 0x85, 0x80, 0x8b, 0x48, 0x31, 0xd7, 0xbc, 0x3f, 0xf4, + 0xde, 0xf0, 0x8e, 0x4b, 0x7a, 0x9d, 0xe5, 0x76, 0xd2, 0x65, 0x86, 0xce, 0xc6, 0x4b, + 0x61, 0x16, + ]; + const EXPECTED_TAG: [u8; 16] = [ + 0x1a, 0xe1, 0x0b, 0x59, 0x4f, 0x09, 0xe2, 0x6a, 0x7e, 0x90, 0x2e, 0xcb, 0xd0, 0x60, + 0x06, 0x91, + ]; + + assert_eq!(buffer[..plaintext.len()], EXPECTED_CIPHERTEXT); + assert_eq!(buffer[plaintext.len()..], EXPECTED_TAG); + } + + #[test] + fn symmetric_chacha20_seal_open() { + let aad: [u8; 32] = Default::default(); + let key: [u8; 32] = Default::default(); + let counter = 0; + + let mut encrypted_nothing: [u8; 16] = Default::default(); + + aead_chacha20_seal(&mut encrypted_nothing, &key, counter, &[], &aad); + + eprintln!("encrypted_nothing: {:?}", encrypted_nothing); + + aead_chacha20_open(&mut [], &key, counter, &encrypted_nothing, &aad) + .expect("Should open what we just sealed"); + } +} diff --git a/burrow/src/boringtun/src/noise/mod.rs b/burrow/src/boringtun/src/noise/mod.rs new file mode 100755 index 0000000..79a6b92 --- /dev/null +++ b/burrow/src/boringtun/src/noise/mod.rs @@ -0,0 +1,799 @@ +// Copyright (c) 2019 Cloudflare, Inc. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +pub mod errors; +pub mod handshake; +pub mod rate_limiter; + +mod session; +mod timers; + +use crate::noise::errors::WireGuardError; +use crate::noise::handshake::Handshake; +use crate::noise::rate_limiter::RateLimiter; +use crate::noise::timers::{TimerName, Timers}; +use crate::x25519; + +use std::collections::VecDeque; +use std::convert::{TryFrom, TryInto}; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; +use std::sync::Arc; +use std::time::Duration; + +/// The default value to use for rate limiting, when no other rate limiter is defined +const PEER_HANDSHAKE_RATE_LIMIT: u64 = 10; + +const IPV4_MIN_HEADER_SIZE: usize = 20; +const IPV4_LEN_OFF: usize = 2; +const IPV4_SRC_IP_OFF: usize = 12; +const IPV4_DST_IP_OFF: usize = 16; +const IPV4_IP_SZ: usize = 4; + +const IPV6_MIN_HEADER_SIZE: usize = 40; +const IPV6_LEN_OFF: usize = 4; +const IPV6_SRC_IP_OFF: usize = 8; +const IPV6_DST_IP_OFF: usize = 24; +const IPV6_IP_SZ: usize = 16; + +const IP_LEN_SZ: usize = 2; + +const MAX_QUEUE_DEPTH: usize = 256; +/// number of sessions in the ring, better keep a PoT +const N_SESSIONS: usize = 8; + +#[derive(Debug)] +pub enum TunnResult<'a> { + Done, + Err(WireGuardError), + WriteToNetwork(&'a mut [u8]), + WriteToTunnelV4(&'a mut [u8], Ipv4Addr), + WriteToTunnelV6(&'a mut [u8], Ipv6Addr), +} + +impl<'a> From for TunnResult<'a> { + fn from(err: WireGuardError) -> TunnResult<'a> { + TunnResult::Err(err) + } +} + +/// Tunnel represents a point-to-point WireGuard connection +pub struct Tunn { + /// The handshake currently in progress + handshake: handshake::Handshake, + /// The N_SESSIONS most recent sessions, index is session id modulo N_SESSIONS + sessions: [Option; N_SESSIONS], + /// Index of most recently used session + current: usize, + /// Queue to store blocked packets + packet_queue: VecDeque>, + /// Keeps tabs on the expiring timers + timers: timers::Timers, + tx_bytes: usize, + rx_bytes: usize, + rate_limiter: Arc, +} + +type MessageType = u32; +const HANDSHAKE_INIT: MessageType = 1; +const HANDSHAKE_RESP: MessageType = 2; +const COOKIE_REPLY: MessageType = 3; +const DATA: MessageType = 4; + +const HANDSHAKE_INIT_SZ: usize = 148; +const HANDSHAKE_RESP_SZ: usize = 92; +const COOKIE_REPLY_SZ: usize = 64; +const DATA_OVERHEAD_SZ: usize = 32; + +#[derive(Debug)] +pub struct HandshakeInit<'a> { + sender_idx: u32, + unencrypted_ephemeral: &'a [u8; 32], + encrypted_static: &'a [u8], + encrypted_timestamp: &'a [u8], +} + +#[derive(Debug)] +pub struct HandshakeResponse<'a> { + sender_idx: u32, + pub receiver_idx: u32, + unencrypted_ephemeral: &'a [u8; 32], + encrypted_nothing: &'a [u8], +} + +#[derive(Debug)] +pub struct PacketCookieReply<'a> { + pub receiver_idx: u32, + nonce: &'a [u8], + encrypted_cookie: &'a [u8], +} + +#[derive(Debug)] +pub struct PacketData<'a> { + pub receiver_idx: u32, + counter: u64, + encrypted_encapsulated_packet: &'a [u8], +} + +/// Describes a packet from network +#[derive(Debug)] +pub enum Packet<'a> { + HandshakeInit(HandshakeInit<'a>), + HandshakeResponse(HandshakeResponse<'a>), + PacketCookieReply(PacketCookieReply<'a>), + PacketData(PacketData<'a>), +} + +impl Tunn { + #[inline(always)] + pub fn parse_incoming_packet(src: &[u8]) -> Result { + if src.len() < 4 { + return Err(WireGuardError::InvalidPacket); + } + + // Checks the type, as well as the reserved zero fields + let packet_type = u32::from_le_bytes(src[0..4].try_into().unwrap()); + + Ok(match (packet_type, src.len()) { + (HANDSHAKE_INIT, HANDSHAKE_INIT_SZ) => Packet::HandshakeInit(HandshakeInit { + sender_idx: u32::from_le_bytes(src[4..8].try_into().unwrap()), + unencrypted_ephemeral: <&[u8; 32] as TryFrom<&[u8]>>::try_from(&src[8..40]) + .expect("length already checked above"), + encrypted_static: &src[40..88], + encrypted_timestamp: &src[88..116], + }), + (HANDSHAKE_RESP, HANDSHAKE_RESP_SZ) => Packet::HandshakeResponse(HandshakeResponse { + sender_idx: u32::from_le_bytes(src[4..8].try_into().unwrap()), + receiver_idx: u32::from_le_bytes(src[8..12].try_into().unwrap()), + unencrypted_ephemeral: <&[u8; 32] as TryFrom<&[u8]>>::try_from(&src[12..44]) + .expect("length already checked above"), + encrypted_nothing: &src[44..60], + }), + (COOKIE_REPLY, COOKIE_REPLY_SZ) => Packet::PacketCookieReply(PacketCookieReply { + receiver_idx: u32::from_le_bytes(src[4..8].try_into().unwrap()), + nonce: &src[8..32], + encrypted_cookie: &src[32..64], + }), + (DATA, DATA_OVERHEAD_SZ..=std::usize::MAX) => Packet::PacketData(PacketData { + receiver_idx: u32::from_le_bytes(src[4..8].try_into().unwrap()), + counter: u64::from_le_bytes(src[8..16].try_into().unwrap()), + encrypted_encapsulated_packet: &src[16..], + }), + _ => return Err(WireGuardError::InvalidPacket), + }) + } + + pub fn is_expired(&self) -> bool { + self.handshake.is_expired() + } + + pub fn dst_address(packet: &[u8]) -> Option { + if packet.is_empty() { + return None; + } + + match packet[0] >> 4 { + 4 if packet.len() >= IPV4_MIN_HEADER_SIZE => { + let addr_bytes: [u8; IPV4_IP_SZ] = packet + [IPV4_DST_IP_OFF..IPV4_DST_IP_OFF + IPV4_IP_SZ] + .try_into() + .unwrap(); + Some(IpAddr::from(addr_bytes)) + } + 6 if packet.len() >= IPV6_MIN_HEADER_SIZE => { + let addr_bytes: [u8; IPV6_IP_SZ] = packet + [IPV6_DST_IP_OFF..IPV6_DST_IP_OFF + IPV6_IP_SZ] + .try_into() + .unwrap(); + Some(IpAddr::from(addr_bytes)) + } + _ => None, + } + } + + /// Create a new tunnel using own private key and the peer public key + pub fn new( + static_private: x25519::StaticSecret, + peer_static_public: x25519::PublicKey, + preshared_key: Option<[u8; 32]>, + persistent_keepalive: Option, + index: u32, + rate_limiter: Option>, + ) -> Result { + let static_public = x25519::PublicKey::from(&static_private); + + let tunn = Tunn { + handshake: Handshake::new( + static_private, + static_public, + peer_static_public, + index << 8, + preshared_key, + ) + .map_err(|_| "Invalid parameters")?, + sessions: Default::default(), + current: Default::default(), + tx_bytes: Default::default(), + rx_bytes: Default::default(), + + packet_queue: VecDeque::new(), + timers: Timers::new(persistent_keepalive, rate_limiter.is_none()), + + rate_limiter: rate_limiter.unwrap_or_else(|| { + Arc::new(RateLimiter::new(&static_public, PEER_HANDSHAKE_RATE_LIMIT)) + }), + }; + + Ok(tunn) + } + + /// Update the private key and clear existing sessions + pub fn set_static_private( + &mut self, + static_private: x25519::StaticSecret, + static_public: x25519::PublicKey, + rate_limiter: Option>, + ) -> Result<(), WireGuardError> { + self.timers.should_reset_rr = rate_limiter.is_none(); + self.rate_limiter = rate_limiter.unwrap_or_else(|| { + Arc::new(RateLimiter::new(&static_public, PEER_HANDSHAKE_RATE_LIMIT)) + }); + self.handshake + .set_static_private(static_private, static_public)?; + for s in &mut self.sessions { + *s = None; + } + Ok(()) + } + + /// Encapsulate a single packet from the tunnel interface. + /// Returns TunnResult. + /// + /// # Panics + /// Panics if dst buffer is too small. + /// Size of dst should be at least src.len() + 32, and no less than 148 bytes. + pub fn encapsulate<'a>(&mut self, src: &[u8], dst: &'a mut [u8]) -> TunnResult<'a> { + let current = self.current; + if let Some(ref session) = self.sessions[current % N_SESSIONS] { + // Send the packet using an established session + let packet = session.format_packet_data(src, dst); + self.timer_tick(TimerName::TimeLastPacketSent); + // Exclude Keepalive packets from timer update. + if !src.is_empty() { + self.timer_tick(TimerName::TimeLastDataPacketSent); + } + self.tx_bytes += src.len(); + return TunnResult::WriteToNetwork(packet); + } + + // If there is no session, queue the packet for future retry + self.queue_packet(src); + // Initiate a new handshake if none is in progress + self.format_handshake_initiation(dst, false) + } + + /// Receives a UDP datagram from the network and parses it. + /// Returns TunnResult. + /// + /// If the result is of type TunnResult::WriteToNetwork, should repeat the call with empty datagram, + /// until TunnResult::Done is returned. If batch processing packets, it is OK to defer until last + /// packet is processed. + pub fn decapsulate<'a>( + &mut self, + src_addr: Option, + datagram: &[u8], + dst: &'a mut [u8], + ) -> TunnResult<'a> { + if datagram.is_empty() { + // Indicates a repeated call + return self.send_queued_packet(dst); + } + + let mut cookie = [0u8; COOKIE_REPLY_SZ]; + let packet = match self + .rate_limiter + .verify_packet(src_addr, datagram, &mut cookie) + { + Ok(packet) => packet, + Err(TunnResult::WriteToNetwork(cookie)) => { + dst[..cookie.len()].copy_from_slice(cookie); + return TunnResult::WriteToNetwork(&mut dst[..cookie.len()]); + } + Err(TunnResult::Err(e)) => return TunnResult::Err(e), + _ => unreachable!(), + }; + + self.handle_verified_packet(packet, dst) + } + + pub(crate) fn handle_verified_packet<'a>( + &mut self, + packet: Packet, + dst: &'a mut [u8], + ) -> TunnResult<'a> { + match packet { + Packet::HandshakeInit(p) => self.handle_handshake_init(p, dst), + Packet::HandshakeResponse(p) => self.handle_handshake_response(p, dst), + Packet::PacketCookieReply(p) => self.handle_cookie_reply(p), + Packet::PacketData(p) => self.handle_data(p, dst), + } + .unwrap_or_else(TunnResult::from) + } + + fn handle_handshake_init<'a>( + &mut self, + p: HandshakeInit, + dst: &'a mut [u8], + ) -> Result, WireGuardError> { + tracing::debug!( + message = "Received handshake_initiation", + remote_idx = p.sender_idx + ); + + let (packet, session) = self.handshake.receive_handshake_initialization(p, dst)?; + + // Store new session in ring buffer + let index = session.local_index(); + self.sessions[index % N_SESSIONS] = Some(session); + + self.timer_tick(TimerName::TimeLastPacketReceived); + self.timer_tick(TimerName::TimeLastPacketSent); + self.timer_tick_session_established(false, index); // New session established, we are not the initiator + + tracing::debug!(message = "Sending handshake_response", local_idx = index); + + Ok(TunnResult::WriteToNetwork(packet)) + } + + fn handle_handshake_response<'a>( + &mut self, + p: HandshakeResponse, + dst: &'a mut [u8], + ) -> Result, WireGuardError> { + tracing::debug!( + message = "Received handshake_response", + local_idx = p.receiver_idx, + remote_idx = p.sender_idx + ); + + let session = self.handshake.receive_handshake_response(p)?; + + let keepalive_packet = session.format_packet_data(&[], dst); + // Store new session in ring buffer + let l_idx = session.local_index(); + let index = l_idx % N_SESSIONS; + self.sessions[index] = Some(session); + + self.timer_tick(TimerName::TimeLastPacketReceived); + self.timer_tick_session_established(true, index); // New session established, we are the initiator + self.set_current_session(l_idx); + + tracing::debug!("Sending keepalive"); + + Ok(TunnResult::WriteToNetwork(keepalive_packet)) // Send a keepalive as a response + } + + fn handle_cookie_reply<'a>( + &mut self, + p: PacketCookieReply, + ) -> Result, WireGuardError> { + tracing::debug!( + message = "Received cookie_reply", + local_idx = p.receiver_idx + ); + + self.handshake.receive_cookie_reply(p)?; + self.timer_tick(TimerName::TimeLastPacketReceived); + self.timer_tick(TimerName::TimeCookieReceived); + + tracing::debug!("Did set cookie"); + + Ok(TunnResult::Done) + } + + /// Update the index of the currently used session, if needed + fn set_current_session(&mut self, new_idx: usize) { + let cur_idx = self.current; + if cur_idx == new_idx { + // There is nothing to do, already using this session, this is the common case + return; + } + if self.sessions[cur_idx % N_SESSIONS].is_none() + || self.timers.session_timers[new_idx % N_SESSIONS] + >= self.timers.session_timers[cur_idx % N_SESSIONS] + { + self.current = new_idx; + tracing::debug!(message = "New session", session = new_idx); + } + } + + /// Decrypts a data packet, and stores the decapsulated packet in dst. + fn handle_data<'a>( + &mut self, + packet: PacketData, + dst: &'a mut [u8], + ) -> Result, WireGuardError> { + let r_idx = packet.receiver_idx as usize; + let idx = r_idx % N_SESSIONS; + + // Get the (probably) right session + let decapsulated_packet = { + let session = self.sessions[idx].as_ref(); + let session = session.ok_or_else(|| { + tracing::trace!(message = "No current session available", remote_idx = r_idx); + WireGuardError::NoCurrentSession + })?; + session.receive_packet_data(packet, dst)? + }; + + self.set_current_session(r_idx); + + self.timer_tick(TimerName::TimeLastPacketReceived); + + Ok(self.validate_decapsulated_packet(decapsulated_packet)) + } + + /// Formats a new handshake initiation message and store it in dst. If force_resend is true will send + /// a new handshake, even if a handshake is already in progress (for example when a handshake times out) + pub fn format_handshake_initiation<'a>( + &mut self, + dst: &'a mut [u8], + force_resend: bool, + ) -> TunnResult<'a> { + if self.handshake.is_in_progress() && !force_resend { + return TunnResult::Done; + } + + if self.handshake.is_expired() { + self.timers.clear(); + } + + let starting_new_handshake = !self.handshake.is_in_progress(); + + match self.handshake.format_handshake_initiation(dst) { + Ok(packet) => { + tracing::debug!("Sending handshake_initiation"); + + if starting_new_handshake { + self.timer_tick(TimerName::TimeLastHandshakeStarted); + } + self.timer_tick(TimerName::TimeLastPacketSent); + TunnResult::WriteToNetwork(packet) + } + Err(e) => TunnResult::Err(e), + } + } + + /// Check if an IP packet is v4 or v6, truncate to the length indicated by the length field + /// Returns the truncated packet and the source IP as TunnResult + fn validate_decapsulated_packet<'a>(&mut self, packet: &'a mut [u8]) -> TunnResult<'a> { + let (computed_len, src_ip_address) = match packet.len() { + 0 => return TunnResult::Done, // This is keepalive, and not an error + _ if packet[0] >> 4 == 4 && packet.len() >= IPV4_MIN_HEADER_SIZE => { + let len_bytes: [u8; IP_LEN_SZ] = packet[IPV4_LEN_OFF..IPV4_LEN_OFF + IP_LEN_SZ] + .try_into() + .unwrap(); + let addr_bytes: [u8; IPV4_IP_SZ] = packet + [IPV4_SRC_IP_OFF..IPV4_SRC_IP_OFF + IPV4_IP_SZ] + .try_into() + .unwrap(); + ( + u16::from_be_bytes(len_bytes) as usize, + IpAddr::from(addr_bytes), + ) + } + _ if packet[0] >> 4 == 6 && packet.len() >= IPV6_MIN_HEADER_SIZE => { + let len_bytes: [u8; IP_LEN_SZ] = packet[IPV6_LEN_OFF..IPV6_LEN_OFF + IP_LEN_SZ] + .try_into() + .unwrap(); + let addr_bytes: [u8; IPV6_IP_SZ] = packet + [IPV6_SRC_IP_OFF..IPV6_SRC_IP_OFF + IPV6_IP_SZ] + .try_into() + .unwrap(); + ( + u16::from_be_bytes(len_bytes) as usize + IPV6_MIN_HEADER_SIZE, + IpAddr::from(addr_bytes), + ) + } + _ => return TunnResult::Err(WireGuardError::InvalidPacket), + }; + + if computed_len > packet.len() { + return TunnResult::Err(WireGuardError::InvalidPacket); + } + + self.timer_tick(TimerName::TimeLastDataPacketReceived); + self.rx_bytes += computed_len; + + match src_ip_address { + IpAddr::V4(addr) => TunnResult::WriteToTunnelV4(&mut packet[..computed_len], addr), + IpAddr::V6(addr) => TunnResult::WriteToTunnelV6(&mut packet[..computed_len], addr), + } + } + + /// Get a packet from the queue, and try to encapsulate it + fn send_queued_packet<'a>(&mut self, dst: &'a mut [u8]) -> TunnResult<'a> { + if let Some(packet) = self.dequeue_packet() { + match self.encapsulate(&packet, dst) { + TunnResult::Err(_) => { + // On error, return packet to the queue + self.requeue_packet(packet); + } + r => return r, + } + } + TunnResult::Done + } + + /// Push packet to the back of the queue + fn queue_packet(&mut self, packet: &[u8]) { + if self.packet_queue.len() < MAX_QUEUE_DEPTH { + // Drop if too many are already in queue + self.packet_queue.push_back(packet.to_vec()); + } + } + + /// Push packet to the front of the queue + fn requeue_packet(&mut self, packet: Vec) { + if self.packet_queue.len() < MAX_QUEUE_DEPTH { + // Drop if too many are already in queue + self.packet_queue.push_front(packet); + } + } + + fn dequeue_packet(&mut self) -> Option> { + self.packet_queue.pop_front() + } + + fn estimate_loss(&self) -> f32 { + let session_idx = self.current; + + let mut weight = 9.0; + let mut cur_avg = 0.0; + let mut total_weight = 0.0; + + for i in 0..N_SESSIONS { + if let Some(ref session) = self.sessions[(session_idx.wrapping_sub(i)) % N_SESSIONS] { + let (expected, received) = session.current_packet_cnt(); + + let loss = if expected == 0 { + 0.0 + } else { + 1.0 - received as f32 / expected as f32 + }; + + cur_avg += loss * weight; + total_weight += weight; + weight /= 3.0; + } + } + + if total_weight == 0.0 { + 0.0 + } else { + cur_avg / total_weight + } + } + + /// Return stats from the tunnel: + /// * Time since last handshake in seconds + /// * Data bytes sent + /// * Data bytes received + pub fn stats(&self) -> (Option, usize, usize, f32, Option) { + let time = self.time_since_last_handshake(); + let tx_bytes = self.tx_bytes; + let rx_bytes = self.rx_bytes; + let loss = self.estimate_loss(); + let rtt = self.handshake.last_rtt; + + (time, tx_bytes, rx_bytes, loss, rtt) + } +} + +#[cfg(test)] +mod tests { + #[cfg(feature = "mock-instant")] + use crate::noise::timers::{REKEY_AFTER_TIME, REKEY_TIMEOUT}; + + use super::*; + use rand_core::{OsRng, RngCore}; + + fn create_two_tuns() -> (Tunn, Tunn) { + let my_secret_key = x25519_dalek::StaticSecret::random_from_rng(OsRng); + let my_public_key = x25519_dalek::PublicKey::from(&my_secret_key); + let my_idx = OsRng.next_u32(); + + let their_secret_key = x25519_dalek::StaticSecret::random_from_rng(OsRng); + let their_public_key = x25519_dalek::PublicKey::from(&their_secret_key); + let their_idx = OsRng.next_u32(); + + let my_tun = Tunn::new(my_secret_key, their_public_key, None, None, my_idx, None).unwrap(); + + let their_tun = + Tunn::new(their_secret_key, my_public_key, None, None, their_idx, None).unwrap(); + + (my_tun, their_tun) + } + + fn create_handshake_init(tun: &mut Tunn) -> Vec { + let mut dst = vec![0u8; 2048]; + let handshake_init = tun.format_handshake_initiation(&mut dst, false); + assert!(matches!(handshake_init, TunnResult::WriteToNetwork(_))); + let handshake_init = if let TunnResult::WriteToNetwork(sent) = handshake_init { + sent + } else { + unreachable!(); + }; + + handshake_init.into() + } + + fn create_handshake_response(tun: &mut Tunn, handshake_init: &[u8]) -> Vec { + let mut dst = vec![0u8; 2048]; + let handshake_resp = tun.decapsulate(None, handshake_init, &mut dst); + assert!(matches!(handshake_resp, TunnResult::WriteToNetwork(_))); + + let handshake_resp = if let TunnResult::WriteToNetwork(sent) = handshake_resp { + sent + } else { + unreachable!(); + }; + + handshake_resp.into() + } + + fn parse_handshake_resp(tun: &mut Tunn, handshake_resp: &[u8]) -> Vec { + let mut dst = vec![0u8; 2048]; + let keepalive = tun.decapsulate(None, handshake_resp, &mut dst); + assert!(matches!(keepalive, TunnResult::WriteToNetwork(_))); + + let keepalive = if let TunnResult::WriteToNetwork(sent) = keepalive { + sent + } else { + unreachable!(); + }; + + keepalive.into() + } + + fn parse_keepalive(tun: &mut Tunn, keepalive: &[u8]) { + let mut dst = vec![0u8; 2048]; + let keepalive = tun.decapsulate(None, keepalive, &mut dst); + assert!(matches!(keepalive, TunnResult::Done)); + } + + fn create_two_tuns_and_handshake() -> (Tunn, Tunn) { + let (mut my_tun, mut their_tun) = create_two_tuns(); + let init = create_handshake_init(&mut my_tun); + let resp = create_handshake_response(&mut their_tun, &init); + let keepalive = parse_handshake_resp(&mut my_tun, &resp); + parse_keepalive(&mut their_tun, &keepalive); + + (my_tun, their_tun) + } + + fn create_ipv4_udp_packet() -> Vec { + let header = + etherparse::PacketBuilder::ipv4([192, 168, 1, 2], [192, 168, 1, 3], 5).udp(5678, 23); + let payload = [0, 1, 2, 3]; + let mut packet = Vec::::with_capacity(header.size(payload.len())); + header.write(&mut packet, &payload).unwrap(); + packet + } + + #[cfg(feature = "mock-instant")] + fn update_timer_results_in_handshake(tun: &mut Tunn) { + let mut dst = vec![0u8; 2048]; + let result = tun.update_timers(&mut dst); + assert!(matches!(result, TunnResult::WriteToNetwork(_))); + let packet_data = if let TunnResult::WriteToNetwork(data) = result { + data + } else { + unreachable!(); + }; + let packet = Tunn::parse_incoming_packet(packet_data).unwrap(); + assert!(matches!(packet, Packet::HandshakeInit(_))); + } + + #[test] + fn create_two_tunnels_linked_to_eachother() { + let (_my_tun, _their_tun) = create_two_tuns(); + } + + #[test] + fn handshake_init() { + let (mut my_tun, _their_tun) = create_two_tuns(); + let init = create_handshake_init(&mut my_tun); + let packet = Tunn::parse_incoming_packet(&init).unwrap(); + assert!(matches!(packet, Packet::HandshakeInit(_))); + } + + #[test] + fn handshake_init_and_response() { + let (mut my_tun, mut their_tun) = create_two_tuns(); + let init = create_handshake_init(&mut my_tun); + let resp = create_handshake_response(&mut their_tun, &init); + let packet = Tunn::parse_incoming_packet(&resp).unwrap(); + assert!(matches!(packet, Packet::HandshakeResponse(_))); + } + + #[test] + fn full_handshake() { + let (mut my_tun, mut their_tun) = create_two_tuns(); + let init = create_handshake_init(&mut my_tun); + let resp = create_handshake_response(&mut their_tun, &init); + let keepalive = parse_handshake_resp(&mut my_tun, &resp); + let packet = Tunn::parse_incoming_packet(&keepalive).unwrap(); + assert!(matches!(packet, Packet::PacketData(_))); + } + + #[test] + fn full_handshake_plus_timers() { + let (mut my_tun, mut their_tun) = create_two_tuns_and_handshake(); + // Time has not yet advanced so their is nothing to do + assert!(matches!(my_tun.update_timers(&mut []), TunnResult::Done)); + assert!(matches!(their_tun.update_timers(&mut []), TunnResult::Done)); + } + + #[test] + #[cfg(feature = "mock-instant")] + fn new_handshake_after_two_mins() { + let (mut my_tun, mut their_tun) = create_two_tuns_and_handshake(); + let mut my_dst = [0u8; 1024]; + + // Advance time 1 second and "send" 1 packet so that we send a handshake + // after the timeout + mock_instant::MockClock::advance(Duration::from_secs(1)); + assert!(matches!(their_tun.update_timers(&mut []), TunnResult::Done)); + assert!(matches!( + my_tun.update_timers(&mut my_dst), + TunnResult::Done + )); + let sent_packet_buf = create_ipv4_udp_packet(); + let data = my_tun.encapsulate(&sent_packet_buf, &mut my_dst); + assert!(matches!(data, TunnResult::WriteToNetwork(_))); + + //Advance to timeout + mock_instant::MockClock::advance(REKEY_AFTER_TIME); + assert!(matches!(their_tun.update_timers(&mut []), TunnResult::Done)); + update_timer_results_in_handshake(&mut my_tun); + } + + #[test] + #[cfg(feature = "mock-instant")] + fn handshake_no_resp_rekey_timeout() { + let (mut my_tun, _their_tun) = create_two_tuns(); + + let init = create_handshake_init(&mut my_tun); + let packet = Tunn::parse_incoming_packet(&init).unwrap(); + assert!(matches!(packet, Packet::HandshakeInit(_))); + + mock_instant::MockClock::advance(REKEY_TIMEOUT); + update_timer_results_in_handshake(&mut my_tun) + } + + #[test] + fn one_ip_packet() { + let (mut my_tun, mut their_tun) = create_two_tuns_and_handshake(); + let mut my_dst = [0u8; 1024]; + let mut their_dst = [0u8; 1024]; + + let sent_packet_buf = create_ipv4_udp_packet(); + + let data = my_tun.encapsulate(&sent_packet_buf, &mut my_dst); + assert!(matches!(data, TunnResult::WriteToNetwork(_))); + let data = if let TunnResult::WriteToNetwork(sent) = data { + sent + } else { + unreachable!(); + }; + + let data = their_tun.decapsulate(None, data, &mut their_dst); + assert!(matches!(data, TunnResult::WriteToTunnelV4(..))); + let recv_packet_buf = if let TunnResult::WriteToTunnelV4(recv, _addr) = data { + recv + } else { + unreachable!(); + }; + assert_eq!(sent_packet_buf, recv_packet_buf); + } +} diff --git a/burrow/src/boringtun/src/noise/rate_limiter.rs b/burrow/src/boringtun/src/noise/rate_limiter.rs new file mode 100755 index 0000000..052cbb3 --- /dev/null +++ b/burrow/src/boringtun/src/noise/rate_limiter.rs @@ -0,0 +1,193 @@ +use super::handshake::{b2s_hash, b2s_keyed_mac_16, b2s_keyed_mac_16_2, b2s_mac_24}; +use crate::noise::handshake::{LABEL_COOKIE, LABEL_MAC1}; +use crate::noise::{HandshakeInit, HandshakeResponse, Packet, Tunn, TunnResult, WireGuardError}; + +#[cfg(feature = "mock-instant")] +use mock_instant::Instant; +use std::net::IpAddr; +use std::sync::atomic::{AtomicU64, Ordering}; + +#[cfg(not(feature = "mock-instant"))] +use crate::sleepyinstant::Instant; + +use aead::generic_array::GenericArray; +use aead::{AeadInPlace, KeyInit}; +use chacha20poly1305::{Key, XChaCha20Poly1305}; +use parking_lot::Mutex; +use rand_core::{OsRng, RngCore}; +use ring::constant_time::verify_slices_are_equal; + +const COOKIE_REFRESH: u64 = 128; // Use 128 and not 120 so the compiler can optimize out the division +const COOKIE_SIZE: usize = 16; +const COOKIE_NONCE_SIZE: usize = 24; + +/// How often should reset count in seconds +const RESET_PERIOD: u64 = 1; + +type Cookie = [u8; COOKIE_SIZE]; + +/// There are two places where WireGuard requires "randomness" for cookies +/// * The 24 byte nonce in the cookie massage - here the only goal is to avoid nonce reuse +/// * A secret value that changes every two minutes +/// Because the main goal of the cookie is simply for a party to prove ownership of an IP address +/// we can relax the randomness definition a bit, in order to avoid locking, because using less +/// resources is the main goal of any DoS prevention mechanism. +/// In order to avoid locking and calls to rand we derive pseudo random values using the AEAD and +/// some counters. +pub struct RateLimiter { + /// The key we use to derive the nonce + nonce_key: [u8; 32], + /// The key we use to derive the cookie + secret_key: [u8; 16], + start_time: Instant, + /// A single 64 bit counter (should suffice for many years) + nonce_ctr: AtomicU64, + mac1_key: [u8; 32], + cookie_key: Key, + limit: u64, + /// The counter since last reset + count: AtomicU64, + /// The time last reset was performed on this rate limiter + last_reset: Mutex, +} + +impl RateLimiter { + pub fn new(public_key: &crate::x25519::PublicKey, limit: u64) -> Self { + let mut secret_key = [0u8; 16]; + OsRng.fill_bytes(&mut secret_key); + RateLimiter { + nonce_key: Self::rand_bytes(), + secret_key, + start_time: Instant::now(), + nonce_ctr: AtomicU64::new(0), + mac1_key: b2s_hash(LABEL_MAC1, public_key.as_bytes()), + cookie_key: b2s_hash(LABEL_COOKIE, public_key.as_bytes()).into(), + limit, + count: AtomicU64::new(0), + last_reset: Mutex::new(Instant::now()), + } + } + + fn rand_bytes() -> [u8; 32] { + let mut key = [0u8; 32]; + OsRng.fill_bytes(&mut key); + key + } + + /// Reset packet count (ideally should be called with a period of 1 second) + pub fn reset_count(&self) { + // The rate limiter is not very accurate, but at the scale we care about it doesn't matter much + let current_time = Instant::now(); + let mut last_reset_time = self.last_reset.lock(); + if current_time.duration_since(*last_reset_time).as_secs() >= RESET_PERIOD { + self.count.store(0, Ordering::SeqCst); + *last_reset_time = current_time; + } + } + + /// Compute the correct cookie value based on the current secret value and the source IP + fn current_cookie(&self, addr: IpAddr) -> Cookie { + let mut addr_bytes = [0u8; 16]; + + match addr { + IpAddr::V4(a) => addr_bytes[..4].copy_from_slice(&a.octets()[..]), + IpAddr::V6(a) => addr_bytes[..].copy_from_slice(&a.octets()[..]), + } + + // The current cookie for a given IP is the MAC(responder.changing_secret_every_two_minutes, initiator.ip_address) + // First we derive the secret from the current time, the value of cur_counter would change with time. + let cur_counter = Instant::now().duration_since(self.start_time).as_secs() / COOKIE_REFRESH; + + // Next we derive the cookie + b2s_keyed_mac_16_2(&self.secret_key, &cur_counter.to_le_bytes(), &addr_bytes) + } + + fn nonce(&self) -> [u8; COOKIE_NONCE_SIZE] { + let ctr = self.nonce_ctr.fetch_add(1, Ordering::Relaxed); + + b2s_mac_24(&self.nonce_key, &ctr.to_le_bytes()) + } + + fn is_under_load(&self) -> bool { + self.count.fetch_add(1, Ordering::SeqCst) >= self.limit + } + + pub(crate) fn format_cookie_reply<'a>( + &self, + idx: u32, + cookie: Cookie, + mac1: &[u8], + dst: &'a mut [u8], + ) -> Result<&'a mut [u8], WireGuardError> { + if dst.len() < super::COOKIE_REPLY_SZ { + return Err(WireGuardError::DestinationBufferTooSmall); + } + + let (message_type, rest) = dst.split_at_mut(4); + let (receiver_index, rest) = rest.split_at_mut(4); + let (nonce, rest) = rest.split_at_mut(24); + let (encrypted_cookie, _) = rest.split_at_mut(16 + 16); + + // msg.message_type = 3 + // msg.reserved_zero = { 0, 0, 0 } + message_type.copy_from_slice(&super::COOKIE_REPLY.to_le_bytes()); + // msg.receiver_index = little_endian(initiator.sender_index) + receiver_index.copy_from_slice(&idx.to_le_bytes()); + nonce.copy_from_slice(&self.nonce()[..]); + + let cipher = XChaCha20Poly1305::new(&self.cookie_key); + + let iv = GenericArray::from_slice(nonce); + + encrypted_cookie[..16].copy_from_slice(&cookie); + let tag = cipher + .encrypt_in_place_detached(iv, mac1, &mut encrypted_cookie[..16]) + .map_err(|_| WireGuardError::DestinationBufferTooSmall)?; + + encrypted_cookie[16..].copy_from_slice(&tag); + + Ok(&mut dst[..super::COOKIE_REPLY_SZ]) + } + + /// Verify the MAC fields on the datagram, and apply rate limiting if needed + pub fn verify_packet<'a, 'b>( + &self, + src_addr: Option, + src: &'a [u8], + dst: &'b mut [u8], + ) -> Result, TunnResult<'b>> { + let packet = Tunn::parse_incoming_packet(src)?; + + // Verify and rate limit handshake messages only + if let Packet::HandshakeInit(HandshakeInit { sender_idx, .. }) + | Packet::HandshakeResponse(HandshakeResponse { sender_idx, .. }) = packet + { + let (msg, macs) = src.split_at(src.len() - 32); + let (mac1, mac2) = macs.split_at(16); + + let computed_mac1 = b2s_keyed_mac_16(&self.mac1_key, msg); + verify_slices_are_equal(&computed_mac1[..16], mac1) + .map_err(|_| TunnResult::Err(WireGuardError::InvalidMac))?; + + if self.is_under_load() { + let addr = match src_addr { + None => return Err(TunnResult::Err(WireGuardError::UnderLoad)), + Some(addr) => addr, + }; + + // Only given an address can we validate mac2 + let cookie = self.current_cookie(addr); + let computed_mac2 = b2s_keyed_mac_16_2(&cookie, msg, mac1); + + if verify_slices_are_equal(&computed_mac2[..16], mac2).is_err() { + let cookie_packet = self + .format_cookie_reply(sender_idx, cookie, mac1, dst) + .map_err(TunnResult::Err)?; + return Err(TunnResult::WriteToNetwork(cookie_packet)); + } + } + } + + Ok(packet) + } +} diff --git a/burrow/src/boringtun/src/noise/session.rs b/burrow/src/boringtun/src/noise/session.rs new file mode 100755 index 0000000..0d05b95 --- /dev/null +++ b/burrow/src/boringtun/src/noise/session.rs @@ -0,0 +1,329 @@ +// Copyright (c) 2019 Cloudflare, Inc. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +use super::PacketData; +use crate::noise::errors::WireGuardError; +use parking_lot::Mutex; +use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305}; +use std::sync::atomic::{AtomicUsize, Ordering}; + +pub struct Session { + pub(crate) receiving_index: u32, + sending_index: u32, + receiver: LessSafeKey, + sender: LessSafeKey, + sending_key_counter: AtomicUsize, + receiving_key_counter: Mutex, +} + +impl std::fmt::Debug for Session { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!( + f, + "Session: {}<- ->{}", + self.receiving_index, self.sending_index + ) + } +} + +/// Where encrypted data resides in a data packet +const DATA_OFFSET: usize = 16; +/// The overhead of the AEAD +const AEAD_SIZE: usize = 16; + +// Receiving buffer constants +const WORD_SIZE: u64 = 64; +const N_WORDS: u64 = 16; // Suffice to reorder 64*16 = 1024 packets; can be increased at will +const N_BITS: u64 = WORD_SIZE * N_WORDS; + +#[derive(Debug, Clone, Default)] +struct ReceivingKeyCounterValidator { + /// In order to avoid replays while allowing for some reordering of the packets, we keep a + /// bitmap of received packets, and the value of the highest counter + next: u64, + /// Used to estimate packet loss + receive_cnt: u64, + bitmap: [u64; N_WORDS as usize], +} + +impl ReceivingKeyCounterValidator { + #[inline(always)] + fn set_bit(&mut self, idx: u64) { + let bit_idx = idx % N_BITS; + let word = (bit_idx / WORD_SIZE) as usize; + let bit = (bit_idx % WORD_SIZE) as usize; + self.bitmap[word] |= 1 << bit; + } + + #[inline(always)] + fn clear_bit(&mut self, idx: u64) { + let bit_idx = idx % N_BITS; + let word = (bit_idx / WORD_SIZE) as usize; + let bit = (bit_idx % WORD_SIZE) as usize; + self.bitmap[word] &= !(1u64 << bit); + } + + /// Clear the word that contains idx + #[inline(always)] + fn clear_word(&mut self, idx: u64) { + let bit_idx = idx % N_BITS; + let word = (bit_idx / WORD_SIZE) as usize; + self.bitmap[word] = 0; + } + + /// Returns true if bit is set, false otherwise + #[inline(always)] + fn check_bit(&self, idx: u64) -> bool { + let bit_idx = idx % N_BITS; + let word = (bit_idx / WORD_SIZE) as usize; + let bit = (bit_idx % WORD_SIZE) as usize; + ((self.bitmap[word] >> bit) & 1) == 1 + } + + /// Returns true if the counter was not yet received, and is not too far back + #[inline(always)] + fn will_accept(&self, counter: u64) -> Result<(), WireGuardError> { + if counter >= self.next { + // As long as the counter is growing no replay took place for sure + return Ok(()); + } + if counter + N_BITS < self.next { + // Drop if too far back + return Err(WireGuardError::InvalidCounter); + } + if !self.check_bit(counter) { + Ok(()) + } else { + Err(WireGuardError::DuplicateCounter) + } + } + + /// Marks the counter as received, and returns true if it is still good (in case during + /// decryption something changed) + #[inline(always)] + fn mark_did_receive(&mut self, counter: u64) -> Result<(), WireGuardError> { + if counter + N_BITS < self.next { + // Drop if too far back + return Err(WireGuardError::InvalidCounter); + } + if counter == self.next { + // Usually the packets arrive in order, in that case we simply mark the bit and + // increment the counter + self.set_bit(counter); + self.next += 1; + return Ok(()); + } + if counter < self.next { + // A packet arrived out of order, check if it is valid, and mark + if self.check_bit(counter) { + return Err(WireGuardError::InvalidCounter); + } + self.set_bit(counter); + return Ok(()); + } + // Packets where dropped, or maybe reordered, skip them and mark unused + if counter - self.next >= N_BITS { + // Too far ahead, clear all the bits + for c in self.bitmap.iter_mut() { + *c = 0; + } + } else { + let mut i = self.next; + while i % WORD_SIZE != 0 && i < counter { + // Clear until i aligned to word size + self.clear_bit(i); + i += 1; + } + while i + WORD_SIZE < counter { + // Clear whole word at a time + self.clear_word(i); + i = (i + WORD_SIZE) & 0u64.wrapping_sub(WORD_SIZE); + } + while i < counter { + // Clear any remaining bits + self.clear_bit(i); + i += 1; + } + } + self.set_bit(counter); + self.next = counter + 1; + Ok(()) + } +} + +impl Session { + pub(super) fn new( + local_index: u32, + peer_index: u32, + receiving_key: [u8; 32], + sending_key: [u8; 32], + ) -> Session { + Session { + receiving_index: local_index, + sending_index: peer_index, + receiver: LessSafeKey::new( + UnboundKey::new(&CHACHA20_POLY1305, &receiving_key).unwrap(), + ), + sender: LessSafeKey::new(UnboundKey::new(&CHACHA20_POLY1305, &sending_key).unwrap()), + sending_key_counter: AtomicUsize::new(0), + receiving_key_counter: Mutex::new(Default::default()), + } + } + + pub(super) fn local_index(&self) -> usize { + self.receiving_index as usize + } + + /// Returns true if receiving counter is good to use + fn receiving_counter_quick_check(&self, counter: u64) -> Result<(), WireGuardError> { + let counter_validator = self.receiving_key_counter.lock(); + counter_validator.will_accept(counter) + } + + /// Returns true if receiving counter is good to use, and marks it as used { + fn receiving_counter_mark(&self, counter: u64) -> Result<(), WireGuardError> { + let mut counter_validator = self.receiving_key_counter.lock(); + let ret = counter_validator.mark_did_receive(counter); + if ret.is_ok() { + counter_validator.receive_cnt += 1; + } + ret + } + + /// src - an IP packet from the interface + /// dst - pre-allocated space to hold the encapsulating UDP packet to send over the network + /// returns the size of the formatted packet + pub(super) fn format_packet_data<'a>(&self, src: &[u8], dst: &'a mut [u8]) -> &'a mut [u8] { + if dst.len() < src.len() + super::DATA_OVERHEAD_SZ { + panic!("The destination buffer is too small"); + } + + let sending_key_counter = self.sending_key_counter.fetch_add(1, Ordering::Relaxed) as u64; + + let (message_type, rest) = dst.split_at_mut(4); + let (receiver_index, rest) = rest.split_at_mut(4); + let (counter, data) = rest.split_at_mut(8); + + message_type.copy_from_slice(&super::DATA.to_le_bytes()); + receiver_index.copy_from_slice(&self.sending_index.to_le_bytes()); + counter.copy_from_slice(&sending_key_counter.to_le_bytes()); + + // TODO: spec requires padding to 16 bytes, but actually works fine without it + let n = { + let mut nonce = [0u8; 12]; + nonce[4..12].copy_from_slice(&sending_key_counter.to_le_bytes()); + data[..src.len()].copy_from_slice(src); + self.sender + .seal_in_place_separate_tag( + Nonce::assume_unique_for_key(nonce), + Aad::from(&[]), + &mut data[..src.len()], + ) + .map(|tag| { + data[src.len()..src.len() + AEAD_SIZE].copy_from_slice(tag.as_ref()); + src.len() + AEAD_SIZE + }) + .unwrap() + }; + + &mut dst[..DATA_OFFSET + n] + } + + /// packet - a data packet we received from the network + /// dst - pre-allocated space to hold the encapsulated IP packet, to send to the interface + /// dst will always take less space than src + /// return the size of the encapsulated packet on success + pub(super) fn receive_packet_data<'a>( + &self, + packet: PacketData, + dst: &'a mut [u8], + ) -> Result<&'a mut [u8], WireGuardError> { + let ct_len = packet.encrypted_encapsulated_packet.len(); + if dst.len() < ct_len { + // This is a very incorrect use of the library, therefore panic and not error + panic!("The destination buffer is too small"); + } + if packet.receiver_idx != self.receiving_index { + return Err(WireGuardError::WrongIndex); + } + // Don't reuse counters, in case this is a replay attack we want to quickly check the counter without running expensive decryption + self.receiving_counter_quick_check(packet.counter)?; + + let ret = { + let mut nonce = [0u8; 12]; + nonce[4..12].copy_from_slice(&packet.counter.to_le_bytes()); + dst[..ct_len].copy_from_slice(packet.encrypted_encapsulated_packet); + self.receiver + .open_in_place( + Nonce::assume_unique_for_key(nonce), + Aad::from(&[]), + &mut dst[..ct_len], + ) + .map_err(|_| WireGuardError::InvalidAeadTag)? + }; + + // After decryption is done, check counter again, and mark as received + self.receiving_counter_mark(packet.counter)?; + Ok(ret) + } + + /// Returns the estimated downstream packet loss for this session + pub(super) fn current_packet_cnt(&self) -> (u64, u64) { + let counter_validator = self.receiving_key_counter.lock(); + (counter_validator.next, counter_validator.receive_cnt) + } +} + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn test_replay_counter() { + let mut c: ReceivingKeyCounterValidator = Default::default(); + + assert!(c.mark_did_receive(0).is_ok()); + assert!(c.mark_did_receive(0).is_err()); + assert!(c.mark_did_receive(1).is_ok()); + assert!(c.mark_did_receive(1).is_err()); + assert!(c.mark_did_receive(63).is_ok()); + assert!(c.mark_did_receive(63).is_err()); + assert!(c.mark_did_receive(15).is_ok()); + assert!(c.mark_did_receive(15).is_err()); + + for i in 64..N_BITS + 128 { + assert!(c.mark_did_receive(i).is_ok()); + assert!(c.mark_did_receive(i).is_err()); + } + + assert!(c.mark_did_receive(N_BITS * 3).is_ok()); + for i in 0..=N_BITS * 2 { + assert!(matches!( + c.will_accept(i), + Err(WireGuardError::InvalidCounter) + )); + assert!(c.mark_did_receive(i).is_err()); + } + for i in N_BITS * 2 + 1..N_BITS * 3 { + assert!(c.will_accept(i).is_ok()); + } + assert!(matches!( + c.will_accept(N_BITS * 3), + Err(WireGuardError::DuplicateCounter) + )); + + for i in (N_BITS * 2 + 1..N_BITS * 3).rev() { + assert!(c.mark_did_receive(i).is_ok()); + assert!(c.mark_did_receive(i).is_err()); + } + + assert!(c.mark_did_receive(N_BITS * 3 + 70).is_ok()); + assert!(c.mark_did_receive(N_BITS * 3 + 71).is_ok()); + assert!(c.mark_did_receive(N_BITS * 3 + 72).is_ok()); + assert!(c.mark_did_receive(N_BITS * 3 + 72 + 125).is_ok()); + assert!(c.mark_did_receive(N_BITS * 3 + 63).is_ok()); + + assert!(c.mark_did_receive(N_BITS * 3 + 70).is_err()); + assert!(c.mark_did_receive(N_BITS * 3 + 71).is_err()); + assert!(c.mark_did_receive(N_BITS * 3 + 72).is_err()); + } +} diff --git a/burrow/src/boringtun/src/noise/timers.rs b/burrow/src/boringtun/src/noise/timers.rs new file mode 100755 index 0000000..6b91d57 --- /dev/null +++ b/burrow/src/boringtun/src/noise/timers.rs @@ -0,0 +1,335 @@ +// Copyright (c) 2019 Cloudflare, Inc. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +use super::errors::WireGuardError; +use crate::noise::{Tunn, TunnResult}; +use std::mem; +use std::ops::{Index, IndexMut}; + +use std::time::Duration; + +#[cfg(feature = "mock-instant")] +use mock_instant::Instant; + +#[cfg(not(feature = "mock-instant"))] +use crate::sleepyinstant::Instant; + +// Some constants, represent time in seconds +// https://www.wireguard.com/papers/wireguard.pdf#page=14 +pub(crate) const REKEY_AFTER_TIME: Duration = Duration::from_secs(120); +const REJECT_AFTER_TIME: Duration = Duration::from_secs(180); +const REKEY_ATTEMPT_TIME: Duration = Duration::from_secs(90); +pub(crate) const REKEY_TIMEOUT: Duration = Duration::from_secs(5); +const KEEPALIVE_TIMEOUT: Duration = Duration::from_secs(10); +const COOKIE_EXPIRATION_TIME: Duration = Duration::from_secs(120); + +#[derive(Debug)] +pub enum TimerName { + /// Current time, updated each call to `update_timers` + TimeCurrent, + /// Time when last handshake was completed + TimeSessionEstablished, + /// Time the last attempt for a new handshake began + TimeLastHandshakeStarted, + /// Time we last received and authenticated a packet + TimeLastPacketReceived, + /// Time we last send a packet + TimeLastPacketSent, + /// Time we last received and authenticated a DATA packet + TimeLastDataPacketReceived, + /// Time we last send a DATA packet + TimeLastDataPacketSent, + /// Time we last received a cookie + TimeCookieReceived, + /// Time we last sent persistent keepalive + TimePersistentKeepalive, + Top, +} + +use self::TimerName::*; + +#[derive(Debug)] +pub struct Timers { + /// Is the owner of the timer the initiator or the responder for the last handshake? + is_initiator: bool, + /// Start time of the tunnel + time_started: Instant, + timers: [Duration; TimerName::Top as usize], + pub(super) session_timers: [Duration; super::N_SESSIONS], + /// Did we receive data without sending anything back? + want_keepalive: bool, + /// Did we send data without hearing back? + want_handshake: bool, + persistent_keepalive: usize, + /// Should this timer call reset rr function (if not a shared rr instance) + pub(super) should_reset_rr: bool, +} + +impl Timers { + pub(super) fn new(persistent_keepalive: Option, reset_rr: bool) -> Timers { + Timers { + is_initiator: false, + time_started: Instant::now(), + timers: Default::default(), + session_timers: Default::default(), + want_keepalive: Default::default(), + want_handshake: Default::default(), + persistent_keepalive: usize::from(persistent_keepalive.unwrap_or(0)), + should_reset_rr: reset_rr, + } + } + + fn is_initiator(&self) -> bool { + self.is_initiator + } + + // We don't really clear the timers, but we set them to the current time to + // so the reference time frame is the same + pub(super) fn clear(&mut self) { + let now = Instant::now().duration_since(self.time_started); + for t in &mut self.timers[..] { + *t = now; + } + self.want_handshake = false; + self.want_keepalive = false; + } +} + +impl Index for Timers { + type Output = Duration; + fn index(&self, index: TimerName) -> &Duration { + &self.timers[index as usize] + } +} + +impl IndexMut for Timers { + fn index_mut(&mut self, index: TimerName) -> &mut Duration { + &mut self.timers[index as usize] + } +} + +impl Tunn { + pub(super) fn timer_tick(&mut self, timer_name: TimerName) { + match timer_name { + TimeLastPacketReceived => { + self.timers.want_keepalive = true; + self.timers.want_handshake = false; + } + TimeLastPacketSent => { + self.timers.want_handshake = true; + self.timers.want_keepalive = false; + } + _ => {} + } + + let time = self.timers[TimeCurrent]; + self.timers[timer_name] = time; + } + + pub(super) fn timer_tick_session_established( + &mut self, + is_initiator: bool, + session_idx: usize, + ) { + self.timer_tick(TimeSessionEstablished); + self.timers.session_timers[session_idx % crate::noise::N_SESSIONS] = + self.timers[TimeCurrent]; + self.timers.is_initiator = is_initiator; + } + + // We don't really clear the timers, but we set them to the current time to + // so the reference time frame is the same + fn clear_all(&mut self) { + for session in &mut self.sessions { + *session = None; + } + + self.packet_queue.clear(); + + self.timers.clear(); + } + + fn update_session_timers(&mut self, time_now: Duration) { + let timers = &mut self.timers; + + for (i, t) in timers.session_timers.iter_mut().enumerate() { + if time_now - *t > REJECT_AFTER_TIME { + if let Some(session) = self.sessions[i].take() { + tracing::debug!( + message = "SESSION_EXPIRED(REJECT_AFTER_TIME)", + session = session.receiving_index + ); + } + *t = time_now; + } + } + } + + pub fn update_timers<'a>(&mut self, dst: &'a mut [u8]) -> TunnResult<'a> { + let mut handshake_initiation_required = false; + let mut keepalive_required = false; + + let time = Instant::now(); + + if self.timers.should_reset_rr { + self.rate_limiter.reset_count(); + } + + // All the times are counted from tunnel initiation, for efficiency our timers are rounded + // to a second, as there is no real benefit to having highly accurate timers. + let now = time.duration_since(self.timers.time_started); + self.timers[TimeCurrent] = now; + + self.update_session_timers(now); + + // Load timers only once: + let session_established = self.timers[TimeSessionEstablished]; + let handshake_started = self.timers[TimeLastHandshakeStarted]; + let aut_packet_received = self.timers[TimeLastPacketReceived]; + let aut_packet_sent = self.timers[TimeLastPacketSent]; + let data_packet_received = self.timers[TimeLastDataPacketReceived]; + let data_packet_sent = self.timers[TimeLastDataPacketSent]; + let persistent_keepalive = self.timers.persistent_keepalive; + + { + if self.handshake.is_expired() { + return TunnResult::Err(WireGuardError::ConnectionExpired); + } + + // Clear cookie after COOKIE_EXPIRATION_TIME + if self.handshake.has_cookie() + && now - self.timers[TimeCookieReceived] >= COOKIE_EXPIRATION_TIME + { + self.handshake.clear_cookie(); + } + + // All ephemeral private keys and symmetric session keys are zeroed out after + // (REJECT_AFTER_TIME * 3) ms if no new keys have been exchanged. + if now - session_established >= REJECT_AFTER_TIME * 3 { + tracing::error!("CONNECTION_EXPIRED(REJECT_AFTER_TIME * 3)"); + self.handshake.set_expired(); + self.clear_all(); + return TunnResult::Err(WireGuardError::ConnectionExpired); + } + + if let Some(time_init_sent) = self.handshake.timer() { + // Handshake Initiation Retransmission + if now - handshake_started >= REKEY_ATTEMPT_TIME { + // After REKEY_ATTEMPT_TIME ms of trying to initiate a new handshake, + // the retries give up and cease, and clear all existing packets queued + // up to be sent. If a packet is explicitly queued up to be sent, then + // this timer is reset. + tracing::error!("CONNECTION_EXPIRED(REKEY_ATTEMPT_TIME)"); + self.handshake.set_expired(); + self.clear_all(); + return TunnResult::Err(WireGuardError::ConnectionExpired); + } + + if time_init_sent.elapsed() >= REKEY_TIMEOUT { + // We avoid using `time` here, because it can be earlier than `time_init_sent`. + // Once `checked_duration_since` is stable we can use that. + // A handshake initiation is retried after REKEY_TIMEOUT + jitter ms, + // if a response has not been received, where jitter is some random + // value between 0 and 333 ms. + tracing::warn!("HANDSHAKE(REKEY_TIMEOUT)"); + handshake_initiation_required = true; + } + } else { + if self.timers.is_initiator() { + // After sending a packet, if the sender was the original initiator + // of the handshake and if the current session key is REKEY_AFTER_TIME + // ms old, we initiate a new handshake. If the sender was the original + // responder of the handshake, it does not re-initiate a new handshake + // after REKEY_AFTER_TIME ms like the original initiator does. + if session_established < data_packet_sent + && now - session_established >= REKEY_AFTER_TIME + { + tracing::debug!("HANDSHAKE(REKEY_AFTER_TIME (on send))"); + handshake_initiation_required = true; + } + + // After receiving a packet, if the receiver was the original initiator + // of the handshake and if the current session key is REJECT_AFTER_TIME + // - KEEPALIVE_TIMEOUT - REKEY_TIMEOUT ms old, we initiate a new + // handshake. + if session_established < data_packet_received + && now - session_established + >= REJECT_AFTER_TIME - KEEPALIVE_TIMEOUT - REKEY_TIMEOUT + { + tracing::warn!( + "HANDSHAKE(REJECT_AFTER_TIME - KEEPALIVE_TIMEOUT - \ + REKEY_TIMEOUT \ + (on receive))" + ); + handshake_initiation_required = true; + } + } + + // If we have sent a packet to a given peer but have not received a + // packet after from that peer for (KEEPALIVE + REKEY_TIMEOUT) ms, + // we initiate a new handshake. + if data_packet_sent > aut_packet_received + && now - aut_packet_received >= KEEPALIVE_TIMEOUT + REKEY_TIMEOUT + && mem::replace(&mut self.timers.want_handshake, false) + { + tracing::warn!("HANDSHAKE(KEEPALIVE + REKEY_TIMEOUT)"); + handshake_initiation_required = true; + } + + if !handshake_initiation_required { + // If a packet has been received from a given peer, but we have not sent one back + // to the given peer in KEEPALIVE ms, we send an empty packet. + if data_packet_received > aut_packet_sent + && now - aut_packet_sent >= KEEPALIVE_TIMEOUT + && mem::replace(&mut self.timers.want_keepalive, false) + { + tracing::debug!("KEEPALIVE(KEEPALIVE_TIMEOUT)"); + keepalive_required = true; + } + + // Persistent KEEPALIVE + if persistent_keepalive > 0 + && (now - self.timers[TimePersistentKeepalive] + >= Duration::from_secs(persistent_keepalive as _)) + { + tracing::debug!("KEEPALIVE(PERSISTENT_KEEPALIVE)"); + self.timer_tick(TimePersistentKeepalive); + keepalive_required = true; + } + } + } + } + + if handshake_initiation_required { + return self.format_handshake_initiation(dst, true); + } + + if keepalive_required { + return self.encapsulate(&[], dst); + } + + TunnResult::Done + } + + pub fn time_since_last_handshake(&self) -> Option { + let current_session = self.current; + if self.sessions[current_session % super::N_SESSIONS].is_some() { + let duration_since_tun_start = Instant::now().duration_since(self.timers.time_started); + let duration_since_session_established = self.timers[TimeSessionEstablished]; + + Some(duration_since_tun_start - duration_since_session_established) + } else { + None + } + } + + pub fn persistent_keepalive(&self) -> Option { + let keepalive = self.timers.persistent_keepalive; + + if keepalive > 0 { + Some(keepalive as u16) + } else { + None + } + } +} diff --git a/burrow/src/boringtun/src/serialization.rs b/burrow/src/boringtun/src/serialization.rs new file mode 100755 index 0000000..e6920f8 --- /dev/null +++ b/burrow/src/boringtun/src/serialization.rs @@ -0,0 +1,33 @@ +pub(crate) struct KeyBytes(pub [u8; 32]); + +impl std::str::FromStr for KeyBytes { + type Err = &'static str; + + /// Can parse a secret key from a hex or base64 encoded string. + fn from_str(s: &str) -> Result { + let mut internal = [0u8; 32]; + + match s.len() { + 64 => { + // Try to parse as hex + for i in 0..32 { + internal[i] = u8::from_str_radix(&s[i * 2..=i * 2 + 1], 16) + .map_err(|_| "Illegal character in key")?; + } + } + 43 | 44 => { + // Try to parse as base64 + if let Ok(decoded_key) = base64::decode(s) { + if decoded_key.len() == internal.len() { + internal[..].copy_from_slice(&decoded_key); + } else { + return Err("Illegal character in key"); + } + } + } + _ => return Err("Illegal key size"), + } + + Ok(KeyBytes(internal)) + } +}