add makefile

This commit is contained in:
Jett Chen 2023-12-17 01:26:37 +08:00
parent 94233874e6
commit 261f24d9ef
24 changed files with 207 additions and 293 deletions

View file

@ -1,8 +1,6 @@
use tracing::instrument::WithSubscriber;
use tracing::{debug, Subscriber};
use tracing::debug;
use tracing_oslog::OsLogger;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::FmtSubscriber;
pub use crate::daemon::start_srv;

View file

@ -23,10 +23,7 @@ fn test_daemoncommand_serialization() {
.unwrap());
insta::assert_snapshot!(
serde_json::to_string(&DaemonCommand::Start(DaemonStartOptions {
tun: TunOptions {
seek_utun: true,
..TunOptions::default()
}
tun: TunOptions { ..TunOptions::default() }
}))
.unwrap()
);

View file

@ -1,10 +1,17 @@
use super::*;
use crate::daemon::response::{DaemonResponseData, ServerConfig, ServerInfo};
use tokio::task::JoinHandle;
use tracing::field::debug;
use std::sync::Arc;
use anyhow::Result;
use tokio::{sync::RwLock, task::JoinHandle};
use tracing::{debug, info, warn};
use tun::tokio::TunInterface;
use DaemonResponse;
use crate::{
daemon::{
command::DaemonCommand,
response::{DaemonResponse, DaemonResponseData, ServerConfig, ServerInfo},
},
wireguard::Interface,
};
enum RunState {
Running(JoinHandle<Result<()>>),
@ -48,7 +55,9 @@ impl DaemonInstance {
}
RunState::Idle => {
debug!("Creating new TunInterface");
let tun_if = Arc::new(RwLock::new(TunInterface::new(st.tun.open()?)?));
let retrieved = TunInterface::new(tun::TunInterface::retrieve().unwrap())?;
let tun_if = Arc::new(RwLock::new(retrieved));
// let tun_if = Arc::new(RwLock::new(TunInterface::new(st.tun.open()?)?));
debug!("TunInterface created: {:?}", tun_if);
debug!("Setting tun_interface");

View file

@ -1,26 +1,29 @@
use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4, ToSocketAddrs};
use std::sync::Arc;
use std::{
net::{IpAddr, Ipv4Addr, SocketAddr},
sync::Arc,
};
mod command;
mod instance;
mod net;
mod response;
use crate::wireguard::{Interface, Peer, PublicKey, StaticSecret};
use anyhow::{Error, Result};
use base64::{engine::general_purpose, Engine as _};
pub use command::{DaemonCommand, DaemonStartOptions};
use fehler::throws;
use instance::DaemonInstance;
use ip_network::{IpNetwork, Ipv4Network};
pub use net::DaemonClient;
use tokio::sync::RwLock;
#[cfg(target_vendor = "apple")]
pub use net::start_srv;
use crate::daemon::net::listen;
pub use net::DaemonClient;
pub use response::{DaemonResponse, DaemonResponseData, ServerInfo};
use tokio::sync::RwLock;
use crate::{
daemon::net::listen,
wireguard::{Interface, Peer, PublicKey, StaticSecret},
};
#[throws]
fn parse_key(string: &str) -> [u8; 32] {

View file

@ -1,9 +1,10 @@
use crate::daemon::{daemon_main, DaemonClient};
use std::future::Future;
use std::thread;
use tokio::runtime::Runtime;
use tracing::{error, info};
use crate::daemon::{daemon_main, DaemonClient};
#[no_mangle]
pub extern "C" fn start_srv() {
info!("Rust: Starting server");
@ -20,7 +21,7 @@ pub extern "C" fn start_srv() {
loop {
match DaemonClient::new().await {
Ok(_) => break,
Err(e) => {
Err(_e) => {
// error!("Error when connecting to daemon: {}", e)
}
}

View file

@ -1,23 +1,21 @@
use super::*;
use anyhow::anyhow;
use log::log;
use std::hash::Hash;
use std::path::PathBuf;
use std::{
ascii, io,
io,
os::{
fd::{FromRawFd, RawFd},
unix::net::UnixListener as StdUnixListener,
},
path::Path};
path::{Path, PathBuf},
};
use tracing::info;
use anyhow::Result;
use crate::daemon::{DaemonCommand, DaemonResponse, DaemonResponseData};
use anyhow::{anyhow, Result};
use tokio::{
io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
net::{UnixListener, UnixStream},
};
use tracing::{debug, info};
use crate::daemon::{DaemonCommand, DaemonResponse, DaemonResponseData};
use tracing::debug;
use super::*;
#[cfg(not(target_vendor = "apple"))]
const UNIX_SOCKET_PATH: &str = "/run/burrow.sock";
@ -38,7 +36,7 @@ fn fetch_socket_path() -> Option<PathBuf> {
for path in tries {
let path = PathBuf::from(path);
if path.exists() {
return Some(path);
return Some(path)
}
}
None

View file

@ -1,4 +1,3 @@
use anyhow::anyhow;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use tun::TunInterface;

View file

@ -1,40 +0,0 @@
use tracing::instrument;
// Check capabilities on Linux
#[cfg(target_os = "linux")]
#[instrument]
pub fn ensure_root() {
use caps::{has_cap, CapSet, Capability};
let cap_net_admin = Capability::CAP_NET_ADMIN;
if let Ok(has_cap) = has_cap(None, CapSet::Effective, cap_net_admin) {
if !has_cap {
eprintln!(
"This action needs the CAP_NET_ADMIN permission. Did you mean to run it as root?"
);
std::process::exit(77);
}
} else {
eprintln!("Failed to check capabilities. Please file a bug report!");
std::process::exit(71);
}
}
// Check for root user on macOS
#[cfg(target_vendor = "apple")]
#[instrument]
pub fn ensure_root() {
use nix::unistd::Uid;
let current_uid = Uid::current();
if !current_uid.is_root() {
eprintln!("This action must be run as root!");
std::process::exit(77);
}
}
#[cfg(target_family = "windows")]
#[instrument]
pub fn ensure_root() {
todo!()
}

View file

@ -1,21 +1,12 @@
pub mod ensureroot;
pub mod wireguard;
use anyhow::Result;
#[cfg(any(target_os = "linux", target_vendor = "apple"))]
use std::{
mem,
os::fd::{AsRawFd, FromRawFd},
};
use tun::TunInterface;
// TODO Separate start and retrieve functions
mod daemon;
pub use daemon::{
DaemonClient, DaemonCommand, DaemonResponse, DaemonResponseData, DaemonStartOptions, ServerInfo,
DaemonClient, DaemonCommand,
DaemonResponse,
DaemonResponseData,
DaemonStartOptions,
ServerInfo,
};
#[cfg(target_vendor = "apple")]

View file

@ -1,24 +1,21 @@
use std::mem;
#[cfg(any(target_os = "linux", target_vendor = "apple"))]
use std::os::fd::FromRawFd;
use anyhow::{Context, Result};
#[cfg(any(target_os = "linux", target_vendor = "apple"))]
use clap::{Args, Parser, Subcommand};
use tracing::{instrument, Level};
use tracing::instrument;
use tracing_log::LogTracer;
use tracing_oslog::OsLogger;
use tracing_subscriber::{prelude::*, EnvFilter, FmtSubscriber};
#[cfg(any(target_os = "linux", target_vendor = "apple"))]
use tun::{retrieve, TunInterface};
use tun::TunInterface;
mod daemon;
mod wireguard;
use crate::daemon::DaemonResponseData;
use daemon::{DaemonClient, DaemonCommand, DaemonStartOptions};
use tun::TunOptions;
use crate::daemon::DaemonResponseData;
#[derive(Parser)]
#[command(name = "Burrow")]
#[command(author = "Hack Club <team@hackclub.com>")]
@ -88,8 +85,7 @@ async fn try_retrieve() -> Result<()> {
}
}
burrow::ensureroot::ensure_root();
let iface2 = retrieve().ok_or(anyhow::anyhow!("No interface found"))?;
let iface2 = TunInterface::retrieve().ok_or(anyhow::anyhow!("No interface found"))?;
tracing::info!("{:?}", iface2);
Ok(())
}

View file

@ -1,24 +1,15 @@
use std::sync::Arc;
use std::time::Duration;
use std::{net::IpAddr, rc::Rc};
use std::{net::IpAddr, sync::Arc, time::Duration};
use anyhow::Error;
use async_trait::async_trait;
use fehler::throws;
use futures::future::join_all;
use futures::FutureExt;
use futures::{future::join_all, FutureExt};
use ip_network_table::IpNetworkTable;
use log::log;
use tokio::time::timeout;
use tokio::{
join,
sync::{Mutex, RwLock},
task::{self, JoinHandle},
};
use tokio::{sync::RwLock, task::JoinHandle, time::timeout};
use tracing::{debug, error};
use tun::tokio::TunInterface;
use super::{noise::Tunnel, pcb, Peer, PeerPcb};
use super::{noise::Tunnel, Peer, PeerPcb};
#[async_trait]
pub trait PacketInterface {
@ -122,12 +113,9 @@ impl Interface {
Ok(Ok(len)) => &buf[..len],
Ok(Err(e)) => {
error!("failed to read from interface: {}", e);
continue;
}
Err(_would_block) => {
debug!("read timed out");
continue;
continue
}
Err(_would_block) => continue,
};
debug!("read {} bytes from interface", src.len());
debug!("bytes: {:?}", src);
@ -138,7 +126,7 @@ impl Interface {
Some(addr) => addr,
None => {
tracing::debug!("no destination found");
continue;
continue
}
};
@ -156,7 +144,7 @@ impl Interface {
}
Err(e) => {
log::error!("failed to send packet {}", e);
continue;
continue
}
};
}
@ -175,20 +163,20 @@ impl Interface {
let pcbs = &self.pcbs;
for i in 0..pcbs.pcbs.len() {
debug!("spawning read task for peer {}", i);
let mut pcb = pcbs.pcbs[i].clone();
let pcb = pcbs.pcbs[i].clone();
let tun = tun.clone();
let tsk = async move {
{
let r1 = pcb.write().await.open_if_closed().await;
if let Err(e) = r1 {
log::error!("failed to open pcb: {}", e);
return;
return
}
}
let r2 = pcb.read().await.run(tun).await;
if let Err(e) = r2 {
log::error!("failed to run pcb: {}", e);
return;
return
} else {
tracing::debug!("pcb ran successfully");
}

View file

@ -9,14 +9,20 @@ use std::{
use aead::{Aead, Payload};
use blake2::{
digest::{FixedOutput, KeyInit},
Blake2s256, Blake2sMac, Digest,
Blake2s256,
Blake2sMac,
Digest,
};
use chacha20poly1305::XChaCha20Poly1305;
use rand_core::OsRng;
use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305};
use super::{
errors::WireGuardError, session::Session, x25519, HandshakeInit, HandshakeResponse,
errors::WireGuardError,
session::Session,
x25519,
HandshakeInit,
HandshakeResponse,
PacketCookieReply,
};
@ -130,10 +136,6 @@ fn aead_chacha20_open(
) -> Result<(), WireGuardError> {
let mut nonce: [u8; 12] = [0; 12];
nonce[4..].copy_from_slice(&counter.to_le_bytes());
tracing::debug!("TAG A");
tracing::debug!("{:?};{:?};{:?};{:?};{}", key, data, aad, nonce, counter);
aead_chacha20_open_inner(buffer, key, nonce, data, aad)
.map_err(|_| WireGuardError::InvalidAeadTag)?;
Ok(())
@ -207,7 +209,7 @@ impl Tai64N {
/// Parse a timestamp from a 12 byte u8 slice
fn parse(buf: &[u8; 12]) -> Result<Tai64N, WireGuardError> {
if buf.len() < 12 {
return Err(WireGuardError::InvalidTai64nTimestamp);
return Err(WireGuardError::InvalidTai64nTimestamp)
}
let (sec_bytes, nano_bytes) = buf.split_at(std::mem::size_of::<u64>());
@ -554,22 +556,19 @@ impl Handshake {
let timestamp = Tai64N::parse(&timestamp)?;
if !timestamp.after(&self.last_handshake_timestamp) {
// Possibly a replay
return Err(WireGuardError::WrongTai64nTimestamp);
return Err(WireGuardError::WrongTai64nTimestamp)
}
self.last_handshake_timestamp = timestamp;
// initiator.hash = HASH(initiator.hash || msg.encrypted_timestamp)
hash = b2s_hash(&hash, packet.encrypted_timestamp);
self.previous = std::mem::replace(
&mut self.state,
HandshakeState::InitReceived {
chaining_key,
hash,
peer_ephemeral_public,
peer_index,
},
);
self.previous = std::mem::replace(&mut self.state, HandshakeState::InitReceived {
chaining_key,
hash,
peer_ephemeral_public,
peer_index,
});
self.format_handshake_response(dst)
}
@ -670,7 +669,7 @@ impl Handshake {
let local_index = self.cookies.index;
if packet.receiver_idx != local_index {
return Err(WireGuardError::WrongIndex);
return Err(WireGuardError::WrongIndex)
}
// msg.encrypted_cookie = XAEAD(HASH(LABEL_COOKIE || responder.static_public),
// msg.nonce, cookie, last_received_msg.mac1)
@ -680,7 +679,6 @@ impl Handshake {
aad: &mac1[0..16],
msg: packet.encrypted_cookie,
};
tracing::debug!("TAG B");
let plaintext = XChaCha20Poly1305::new_from_slice(&key)
.unwrap()
.decrypt(packet.nonce.into(), payload)
@ -727,7 +725,7 @@ impl Handshake {
dst: &'a mut [u8],
) -> Result<&'a mut [u8], WireGuardError> {
if dst.len() < super::HANDSHAKE_INIT_SZ {
return Err(WireGuardError::DestinationBufferTooSmall);
return Err(WireGuardError::DestinationBufferTooSmall)
}
let (message_type, rest) = dst.split_at_mut(4);
@ -810,7 +808,7 @@ impl Handshake {
dst: &'a mut [u8],
) -> Result<(&'a mut [u8], Session), WireGuardError> {
if dst.len() < super::HANDSHAKE_RESP_SZ {
return Err(WireGuardError::DestinationBufferTooSmall);
return Err(WireGuardError::DestinationBufferTooSmall)
}
let state = std::mem::replace(&mut self.state, HandshakeState::None);

View file

@ -45,7 +45,11 @@ const N_SESSIONS: usize = 8;
pub mod x25519 {
pub use x25519_dalek::{
EphemeralSecret, PublicKey, ReusableSecret, SharedSecret, StaticSecret,
EphemeralSecret,
PublicKey,
ReusableSecret,
SharedSecret,
StaticSecret,
};
}
@ -137,7 +141,7 @@ impl Tunnel {
#[inline(always)]
pub fn parse_incoming_packet(src: &[u8]) -> Result<Packet, WireGuardError> {
if src.len() < 4 {
return Err(WireGuardError::InvalidPacket);
return Err(WireGuardError::InvalidPacket)
}
// Checks the type, as well as the reserved zero fields
@ -179,7 +183,7 @@ impl Tunnel {
pub fn dst_address(packet: &[u8]) -> Option<IpAddr> {
if packet.is_empty() {
return None;
return None
}
match packet[0] >> 4 {
@ -274,7 +278,7 @@ impl Tunnel {
self.timer_tick(TimerName::TimeLastDataPacketSent);
}
self.tx_bytes += src.len();
return TunnResult::WriteToNetwork(packet);
return TunnResult::WriteToNetwork(packet)
}
// If there is no session, queue the packet for future retry
@ -298,7 +302,7 @@ impl Tunnel {
) -> TunnResult<'a> {
if datagram.is_empty() {
// Indicates a repeated call
return self.send_queued_packet(dst);
return self.send_queued_packet(dst)
}
let mut cookie = [0u8; COOKIE_REPLY_SZ];
@ -309,7 +313,7 @@ impl Tunnel {
Ok(packet) => packet,
Err(TunnResult::WriteToNetwork(cookie)) => {
dst[..cookie.len()].copy_from_slice(cookie);
return TunnResult::WriteToNetwork(&mut dst[..cookie.len()]);
return TunnResult::WriteToNetwork(&mut dst[..cookie.len()])
}
Err(TunnResult::Err(e)) => return TunnResult::Err(e),
_ => unreachable!(),
@ -409,7 +413,7 @@ impl Tunnel {
let cur_idx = self.current;
if cur_idx == new_idx {
// There is nothing to do, already using this session, this is the common case
return;
return
}
if self.sessions[cur_idx % N_SESSIONS].is_none()
|| self.timers.session_timers[new_idx % N_SESSIONS]
@ -455,7 +459,7 @@ impl Tunnel {
force_resend: bool,
) -> TunnResult<'a> {
if self.handshake.is_in_progress() && !force_resend {
return TunnResult::Done;
return TunnResult::Done
}
if self.handshake.is_expired() {
@ -514,7 +518,7 @@ impl Tunnel {
};
if computed_len > packet.len() {
return TunnResult::Err(WireGuardError::InvalidPacket);
return TunnResult::Err(WireGuardError::InvalidPacket)
}
self.timer_tick(TimerName::TimeLastDataPacketReceived);

View file

@ -6,16 +6,25 @@ use std::{
use aead::{generic_array::GenericArray, AeadInPlace, KeyInit};
use chacha20poly1305::{Key, XChaCha20Poly1305};
use log::log;
use parking_lot::Mutex;
use rand_core::{OsRng, RngCore};
use ring::constant_time::verify_slices_are_equal;
use super::{
handshake::{
b2s_hash, b2s_keyed_mac_16, b2s_keyed_mac_16_2, b2s_mac_24, LABEL_COOKIE, LABEL_MAC1,
b2s_hash,
b2s_keyed_mac_16,
b2s_keyed_mac_16_2,
b2s_mac_24,
LABEL_COOKIE,
LABEL_MAC1,
},
HandshakeInit, HandshakeResponse, Packet, TunnResult, Tunnel, WireGuardError,
HandshakeInit,
HandshakeResponse,
Packet,
TunnResult,
Tunnel,
WireGuardError,
};
const COOKIE_REFRESH: u64 = 128; // Use 128 and not 120 so the compiler can optimize out the division
@ -127,7 +136,7 @@ impl RateLimiter {
dst: &'a mut [u8],
) -> Result<&'a mut [u8], WireGuardError> {
if dst.len() < super::COOKIE_REPLY_SZ {
return Err(WireGuardError::DestinationBufferTooSmall);
return Err(WireGuardError::DestinationBufferTooSmall)
}
let (message_type, rest) = dst.split_at_mut(4);
@ -193,7 +202,7 @@ impl RateLimiter {
let cookie_packet = self
.format_cookie_reply(sender_idx, cookie, mac1, dst)
.map_err(TunnResult::Err)?;
return Err(TunnResult::WriteToNetwork(cookie_packet));
return Err(TunnResult::WriteToNetwork(cookie_packet))
}
}
}

View file

@ -88,11 +88,11 @@ impl ReceivingKeyCounterValidator {
fn will_accept(&self, counter: u64) -> Result<(), WireGuardError> {
if counter >= self.next {
// As long as the counter is growing no replay took place for sure
return Ok(());
return Ok(())
}
if counter + N_BITS < self.next {
// Drop if too far back
return Err(WireGuardError::InvalidCounter);
return Err(WireGuardError::InvalidCounter)
}
if !self.check_bit(counter) {
Ok(())
@ -107,22 +107,22 @@ impl ReceivingKeyCounterValidator {
fn mark_did_receive(&mut self, counter: u64) -> Result<(), WireGuardError> {
if counter + N_BITS < self.next {
// Drop if too far back
return Err(WireGuardError::InvalidCounter);
return Err(WireGuardError::InvalidCounter)
}
if counter == self.next {
// Usually the packets arrive in order, in that case we simply mark the bit and
// increment the counter
self.set_bit(counter);
self.next += 1;
return Ok(());
return Ok(())
}
if counter < self.next {
// A packet arrived out of order, check if it is valid, and mark
if self.check_bit(counter) {
return Err(WireGuardError::InvalidCounter);
return Err(WireGuardError::InvalidCounter)
}
self.set_bit(counter);
return Ok(());
return Ok(())
}
// Packets where dropped, or maybe reordered, skip them and mark unused
if counter - self.next >= N_BITS {
@ -247,7 +247,7 @@ impl Session {
panic!("The destination buffer is too small");
}
if packet.receiver_idx != self.receiving_index {
return Err(WireGuardError::WrongIndex);
return Err(WireGuardError::WrongIndex)
}
// Don't reuse counters, in case this is a replay attack we want to quickly
// check the counter without running expensive decryption

View file

@ -190,7 +190,7 @@ impl Tunnel {
{
if self.handshake.is_expired() {
return TunnResult::Err(WireGuardError::ConnectionExpired);
return TunnResult::Err(WireGuardError::ConnectionExpired)
}
// Clear cookie after COOKIE_EXPIRATION_TIME
@ -206,7 +206,7 @@ impl Tunnel {
tracing::error!("CONNECTION_EXPIRED(REJECT_AFTER_TIME * 3)");
self.handshake.set_expired();
self.clear_all();
return TunnResult::Err(WireGuardError::ConnectionExpired);
return TunnResult::Err(WireGuardError::ConnectionExpired)
}
if let Some(time_init_sent) = self.handshake.timer() {
@ -219,7 +219,7 @@ impl Tunnel {
tracing::error!("CONNECTION_EXPIRED(REKEY_ATTEMPT_TIME)");
self.handshake.set_expired();
self.clear_all();
return TunnResult::Err(WireGuardError::ConnectionExpired);
return TunnResult::Err(WireGuardError::ConnectionExpired)
}
if time_init_sent.elapsed() >= REKEY_TIMEOUT {
@ -299,11 +299,11 @@ impl Tunnel {
}
if handshake_initiation_required {
return self.format_handshake_initiation(dst, true);
return self.format_handshake_initiation(dst, true)
}
if keepalive_required {
return self.encapsulate(&[], dst);
return self.encapsulate(&[], dst)
}
TunnResult::Done

View file

@ -1,19 +1,11 @@
use std::io;
use std::net::SocketAddr;
use std::rc::Rc;
use std::sync::Arc;
use std::time::Duration;
use std::{net::SocketAddr, sync::Arc};
use anyhow::{anyhow, Error};
use fehler::throws;
use ip_network::IpNetwork;
use log::log;
use rand::random;
use tokio::sync::{Mutex, RwLock};
use tokio::time::timeout;
use tokio::{net::UdpSocket, task::JoinHandle};
use tokio::{net::UdpSocket, sync::RwLock, task::JoinHandle};
use tun::tokio::TunInterface;
use uuid::uuid;
use super::{
iface::PacketInterface,
@ -83,16 +75,14 @@ impl PeerPcb {
tracing::debug!("start read loop {}", rid);
loop {
tracing::debug!("{}: waiting for packet", rid);
let Some(socket) = &self.socket else {
continue
};
let Some(socket) = &self.socket else { continue };
let mut res_buf = [0; 1500];
// tracing::debug!("{} : waiting for readability on {:?}", rid, socket);
let len = match socket.recv(&mut res_buf).await {
Ok(l) => l,
Err(e) => {
log::error!("{}: error reading from socket: {:?}", rid, e);
continue;
continue
}
};
let mut res_dat = &res_buf[..len];
@ -105,33 +95,31 @@ impl PeerPcb {
.await
.decapsulate(None, res_dat, &mut buf[..])
{
TunnResult::Done => {
break;
}
TunnResult::Done => break,
TunnResult::Err(e) => {
tracing::error!(message = "Decapsulate error", error = ?e);
break;
break
}
TunnResult::WriteToNetwork(packet) => {
tracing::debug!("WriteToNetwork: {:?}", packet);
socket.send(packet).await?;
tracing::debug!("WriteToNetwork done");
res_dat = &[];
continue;
continue
}
TunnResult::WriteToTunnelV4(packet, addr) => {
tracing::debug!("WriteToTunnelV4: {:?}, {:?}", packet, addr);
tun_interface.read().await.send(packet).await?;
break;
break
}
TunnResult::WriteToTunnelV6(packet, addr) => {
tracing::debug!("WriteToTunnelV6: {:?}, {:?}", packet, addr);
tun_interface.read().await.send(packet).await?;
break;
break
}
}
}
return Ok(len);
return Ok(len)
}
}

View file

@ -1,7 +1,5 @@
use std::{fmt, net::SocketAddr};
use anyhow::Error;
use fehler::throws;
use ip_network::IpNetwork;
use x25519_dalek::{PublicKey, StaticSecret};