Add Wireguard support to Burrow

This commit is contained in:
Jett Chen 2023-12-17 01:20:56 +08:00 committed by Conrad Kramer
parent 60257b256a
commit d3448e2bc7
59 changed files with 3805 additions and 521 deletions

View file

@ -39,5 +39,5 @@ anyhow = "1.0"
bindgen = "0.65"
reqwest = { version = "0.11", features = ["native-tls"] }
ssri = { version = "9.0", default-features = false }
tokio = { version = "1.28", features = ["rt"] }
tokio = { version = "1.28", features = ["rt", "macros"] }
zip = { version = "0.6", features = ["deflate"] }

View file

@ -26,7 +26,7 @@ async fn generate(out_dir: &std::path::Path) -> anyhow::Result<()> {
println!("cargo:rerun-if-changed={}", binary_path.to_str().unwrap());
if let (Ok(..), Ok(..)) = (File::open(&bindings_path), File::open(&binary_path)) {
return Ok(());
return Ok(())
};
let archive = download(out_dir)
@ -80,9 +80,10 @@ async fn download(directory: &std::path::Path) -> anyhow::Result<std::fs::File>
#[cfg(windows)]
fn parse(file: std::fs::File) -> anyhow::Result<(bindgen::Bindings, Vec<u8>)> {
use anyhow::Context;
use std::io::Read;
use anyhow::Context;
let reader = std::io::BufReader::new(file);
let mut archive = zip::ZipArchive::new(reader)?;

View file

@ -2,11 +2,11 @@
#[cfg(target_os = "windows")]
#[path = "windows/mod.rs"]
mod imp;
mod os_imp;
#[cfg(any(target_os = "linux", target_vendor = "apple"))]
#[path = "unix/mod.rs"]
pub(crate) mod imp;
pub(crate) mod os_imp;
mod options;
@ -14,5 +14,5 @@ mod options;
#[cfg(feature = "tokio")]
pub mod tokio;
pub use imp::{TunInterface, TunQueue};
pub use options::TunOptions;
pub use os_imp::{TunInterface, TunQueue};

View file

@ -1,17 +1,27 @@
use fehler::throws;
use std::io::Error;
use super::TunInterface;
use fehler::throws;
#[cfg(any(target_os = "linux", target_vendor = "apple"))]
#[cfg(feature = "tokio")]
use super::tokio::TunInterface;
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema))]
#[cfg_attr(
feature = "serde",
derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema)
)]
pub struct TunOptions {
/// (Windows + Linux) Name the tun interface.
pub(crate) name: Option<String>,
pub name: Option<String>,
/// (Linux) Don't include packet information.
pub(crate) no_pi: Option<()>,
pub no_pi: bool,
/// (Linux) Avoid opening an existing persistant device.
pub(crate) tun_excl: Option<()>,
pub tun_excl: bool,
/// (Apple) Retrieve the tun interface
pub tun_retrieve: bool,
/// (Linux) The IP address of the tun interface.
pub address: Option<String>,
}
impl TunOptions {
@ -24,16 +34,26 @@ impl TunOptions {
self
}
pub fn no_pi(mut self, enable: bool) {
self.no_pi = enable.then_some(());
pub fn no_pi(mut self, enable: bool) -> Self {
self.no_pi = enable;
self
}
pub fn tun_excl(mut self, enable: bool) {
self.tun_excl = enable.then_some(());
pub fn tun_excl(mut self, enable: bool) -> Self {
self.tun_excl = enable;
self
}
pub fn address(mut self, address: impl ToString) -> Self {
self.address = Some(address.to_string());
self
}
#[cfg(any(target_os = "linux", target_vendor = "apple"))]
#[cfg(feature = "tokio")]
#[throws]
pub fn open(self) -> TunInterface {
TunInterface::new_with_options(self)?
let ti = super::TunInterface::new_with_options(self)?;
TunInterface::new(ti)?
}
}

View file

@ -1,22 +1,22 @@
use std::io;
use tokio::io::unix::AsyncFd;
use tracing::instrument;
#[derive(Debug)]
pub struct TunInterface {
inner: AsyncFd<crate::TunInterface>,
pub inner: AsyncFd<crate::TunInterface>,
}
impl TunInterface {
#[instrument]
pub fn new(tun: crate::TunInterface) -> io::Result<Self> {
Ok(Self {
inner: AsyncFd::new(tun)?,
})
pub fn new(mut tun: crate::TunInterface) -> io::Result<Self> {
tun.set_nonblocking(true)?;
Ok(Self { inner: AsyncFd::new(tun)? })
}
#[instrument]
pub async fn write(&self, buf: &[u8]) -> io::Result<usize> {
pub async fn send(&self, buf: &[u8]) -> io::Result<usize> {
loop {
let mut guard = self.inner.writable().await?;
match guard.try_io(|inner| inner.get_ref().send(buf)) {
@ -27,12 +27,15 @@ impl TunInterface {
}
#[instrument]
pub async fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
pub async fn recv(&self, buf: &mut [u8]) -> io::Result<usize> {
loop {
let mut guard = self.inner.readable_mut().await?;
match guard.try_io(|inner| (*inner).get_mut().recv(buf)) {
let mut guard = self.inner.readable().await?;
match guard.try_io(|inner| inner.get_ref().recv(buf)) {
Ok(result) => return result,
Err(_would_block) => continue,
Err(_would_block) => {
tracing::debug!("WouldBlock");
continue
}
}
}
}

View file

@ -1,7 +1,6 @@
use std::{io::Error, mem::size_of, os::unix::io::AsRawFd};
use fehler::throws;
use std::io::Error;
use std::mem::size_of;
use std::os::unix::io::AsRawFd;
use super::sys;
@ -16,10 +15,7 @@ pub trait SysControlSocket {
impl SysControlSocket for socket2::Socket {
#[throws]
fn resolve(&self, name: &str, index: u32) -> socket2::SockAddr {
let mut info = sys::ctl_info {
ctl_id: 0,
ctl_name: [0; 96],
};
let mut info = sys::ctl_info { ctl_id: 0, ctl_name: [0; 96] };
info.ctl_name[..name.len()].copy_from_slice(name.as_bytes());
unsafe { sys::resolve_ctl_info(self.as_raw_fd(), &mut info as *mut sys::ctl_info)? };
@ -28,7 +24,7 @@ impl SysControlSocket for socket2::Socket {
socket2::SockAddr::init(|addr_storage, len| {
*len = size_of::<sys::sockaddr_ctl>() as u32;
let mut addr: &mut sys::sockaddr_ctl = &mut *addr_storage.cast();
let addr: &mut sys::sockaddr_ctl = &mut *addr_storage.cast();
addr.sc_len = *len as u8;
addr.sc_family = sys::AF_SYSTEM as u8;
addr.ss_sysaddr = sys::AF_SYS_CONTROL as u16;

View file

@ -1,21 +1,24 @@
use std::{
io::{Error, IoSlice},
mem,
net::{Ipv4Addr, SocketAddrV4},
os::fd::{AsRawFd, FromRawFd, RawFd},
};
use byteorder::{ByteOrder, NetworkEndian};
use fehler::throws;
use libc::{c_char, iovec, writev, AF_INET, AF_INET6};
use tracing::info;
use socket2::{Domain, SockAddr, Socket, Type};
use std::io::IoSlice;
use std::net::{Ipv4Addr, SocketAddrV4};
use std::os::fd::{AsRawFd, RawFd};
use std::{io::Error, mem};
use tracing::instrument;
use tracing::{self, instrument};
mod kern_control;
mod sys;
pub mod kern_control;
pub mod sys;
use kern_control::SysControlSocket;
pub use super::queue::TunQueue;
use super::{ifname_to_string, string_to_ifname, TunOptions};
use kern_control::SysControlSocket;
use super::{ifname_to_string, string_to_ifname};
use crate::TunOptions;
#[derive(Debug)]
pub struct TunInterface {
@ -31,8 +34,49 @@ impl TunInterface {
#[throws]
#[instrument]
pub fn new_with_options(_: TunOptions) -> TunInterface {
TunInterface::connect(0)?
pub fn new_with_options(options: TunOptions) -> TunInterface {
let ti = if options.tun_retrieve {
TunInterface::retrieve().ok_or(Error::new(
std::io::ErrorKind::NotFound,
"No tun interface found",
))?
} else {
TunInterface::connect(0)?
};
ti.configure(options)?;
ti
}
pub fn retrieve() -> Option<TunInterface> {
(3..100)
.filter_map(|fd| unsafe {
let peer_addr = socket2::SockAddr::init(|storage, len| {
*len = mem::size_of::<sys::sockaddr_ctl>() as u32;
libc::getpeername(fd, storage as *mut _, len);
Ok(())
})
.map(|(_, addr)| (fd, addr));
peer_addr.ok()
})
.filter(|(_fd, addr)| {
let ctl_addr = unsafe { &*(addr.as_ptr() as *const libc::sockaddr_ctl) };
addr.family() == libc::AF_SYSTEM as u8
&& ctl_addr.ss_sysaddr == libc::AF_SYS_CONTROL as u16
})
.map(|(fd, _)| {
let socket = unsafe { socket2::Socket::from_raw_fd(fd) };
TunInterface { socket }
})
.next()
}
#[throws]
fn configure(&self, options: TunOptions) {
if let Some(addr) = options.address {
if let Ok(addr) = addr.parse() {
self.set_ipv4_addr(addr)?;
}
}
}
#[throws]
@ -81,7 +125,7 @@ impl TunInterface {
let mut iff = self.ifreq()?;
iff.ifr_ifru.ifru_addr = unsafe { *addr.as_ptr() };
self.perform(|fd| unsafe { sys::if_set_addr(fd, &iff) })?;
info!("ipv4_addr_set: {:?} (fd: {:?})", addr, self.as_raw_fd())
tracing::info!("ipv4_addr_set: {:?} (fd: {:?})", addr, self.as_raw_fd())
}
#[throws]
@ -118,7 +162,7 @@ impl TunInterface {
let mut iff = self.ifreq()?;
iff.ifr_ifru.ifru_mtu = mtu;
self.perform(|fd| unsafe { sys::if_set_mtu(fd, &iff) })?;
info!("mtu_set: {:?} (fd: {:?})", mtu, self.as_raw_fd())
tracing::info!("mtu_set: {:?} (fd: {:?})", mtu, self.as_raw_fd())
}
#[throws]
@ -140,7 +184,7 @@ impl TunInterface {
let mut iff = self.ifreq()?;
iff.ifr_ifru.ifru_netmask = unsafe { *addr.as_ptr() };
self.perform(|fd| unsafe { sys::if_set_netmask(fd, &iff) })?;
info!(
tracing::info!(
"netmask_set: {:?} (fd: {:?})",
unsafe { iff.ifr_ifru.ifru_netmask },
self.as_raw_fd()

View file

@ -2,11 +2,20 @@ use std::mem;
use libc::{c_char, c_int, c_short, c_uint, c_ulong, sockaddr};
pub use libc::{
c_void, sockaddr_ctl, sockaddr_in, socklen_t, AF_SYSTEM, AF_SYS_CONTROL, IFNAMSIZ,
c_void,
sockaddr_ctl,
sockaddr_in,
socklen_t,
AF_SYSTEM,
AF_SYS_CONTROL,
IFNAMSIZ,
SYSPROTO_CONTROL,
};
use nix::{
ioctl_read_bad, ioctl_readwrite, ioctl_write_ptr_bad, request_code_readwrite,
ioctl_read_bad,
ioctl_readwrite,
ioctl_write_ptr_bad,
request_code_readwrite,
request_code_write,
};

View file

@ -1,18 +1,21 @@
use std::{
fs::OpenOptions,
io::{Error, Write},
mem,
net::{Ipv4Addr, Ipv6Addr, SocketAddrV4},
os::{
fd::RawFd,
unix::io::{AsRawFd, FromRawFd, IntoRawFd},
},
};
use fehler::throws;
use libc::in6_ifreq;
use socket2::{Domain, SockAddr, Socket, Type};
use std::fs::OpenOptions;
use std::io::{Error, Write};
use std::mem;
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4};
use std::os::fd::RawFd;
use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd};
use tracing::{info, instrument};
use libc::in6_ifreq;
use super::{ifname_to_string, string_to_ifname, TunOptions};
use super::{ifname_to_string, string_to_ifname};
use crate::TunOptions;
mod sys;
@ -38,10 +41,10 @@ impl TunInterface {
let mut flags = libc::IFF_TUN as i16;
if options.no_pi.is_some() {
if options.no_pi {
flags |= libc::IFF_NO_PI as i16;
}
if options.tun_excl.is_some() {
if options.tun_excl {
flags |= libc::IFF_TUN_EXCL as i16;
}

View file

@ -1,10 +1,7 @@
use nix::{ioctl_read_bad, ioctl_write_ptr_bad, request_code_read, request_code_write};
use std::mem::size_of;
pub use libc::ifreq;
pub use libc::sockaddr;
pub use libc::sockaddr_in;
pub use libc::sockaddr_in6;
pub use libc::{ifreq, sockaddr, sockaddr_in, sockaddr_in6};
use nix::{ioctl_read_bad, ioctl_write_ptr_bad, request_code_read, request_code_write};
ioctl_write_ptr_bad!(
tun_set_iff,

View file

@ -1,10 +1,10 @@
use std::{
io::{Error, Read},
io::Error,
mem::MaybeUninit,
os::fd::{AsRawFd, FromRawFd, IntoRawFd, RawFd},
};
use tracing::instrument;
use super::TunOptions;
use tracing::instrument;
mod queue;
@ -28,9 +28,8 @@ impl AsRawFd for TunInterface {
impl FromRawFd for TunInterface {
unsafe fn from_raw_fd(fd: RawFd) -> TunInterface {
TunInterface {
socket: socket2::Socket::from_raw_fd(fd),
}
let socket = socket2::Socket::from_raw_fd(fd);
TunInterface { socket }
}
}
@ -40,11 +39,26 @@ impl IntoRawFd for TunInterface {
}
}
unsafe fn assume_init(buf: &[MaybeUninit<u8>]) -> &[u8] {
&*(buf as *const [MaybeUninit<u8>] as *const [u8])
}
impl TunInterface {
#[throws]
#[instrument]
pub fn recv(&mut self, buf: &mut [u8]) -> usize {
self.socket.read(buf)?
pub fn recv(&self, buf: &mut [u8]) -> usize {
// Use IoVec to read directly into target buffer
let mut tmp_buf = [MaybeUninit::uninit(); 1500];
let len = self.socket.recv(&mut tmp_buf)?;
let result_buf = unsafe { assume_init(&tmp_buf[4..len]) };
buf[..len - 4].copy_from_slice(result_buf);
len - 4
}
#[throws]
#[instrument]
pub fn set_nonblocking(&mut self, nb: bool) {
self.socket.set_nonblocking(nb)?;
}
}
@ -65,4 +79,4 @@ pub fn string_to_ifname(name: &str) -> [libc::c_char; libc::IFNAMSIZ] {
let len = name.len().min(buf.len());
buf[..len].copy_from_slice(unsafe { &*(name.as_bytes() as *const _ as *const [libc::c_char]) });
buf
}
}

View file

@ -1,10 +1,10 @@
use fehler::throws;
use std::{
io::{Error, Read, Write},
mem::MaybeUninit,
os::unix::io::{AsRawFd, IntoRawFd, RawFd},
};
use fehler::throws;
use tracing::instrument;
use crate::TunInterface;
@ -15,10 +15,9 @@ pub struct TunQueue {
}
impl TunQueue {
#[throws]
#[instrument]
pub fn recv(&self, buf: &mut [MaybeUninit<u8>]) -> usize {
self.socket.recv(buf)?
pub fn recv(&self, buf: &mut [MaybeUninit<u8>]) -> Result<usize, Error> {
self.socket.recv(buf)
}
}
@ -43,9 +42,7 @@ impl Write for TunQueue {
impl From<TunInterface> for TunQueue {
fn from(interface: TunInterface) -> TunQueue {
TunQueue {
socket: interface.socket,
}
TunQueue { socket: interface.socket }
}
}

View file

@ -1,15 +1,14 @@
use std::fmt::Debug;
use std::{fmt::Debug, io::Error, ptr};
use fehler::throws;
use std::io::Error;
use std::ptr;
use widestring::U16CString;
use windows::Win32::Foundation::GetLastError;
mod queue;
use super::TunOptions;
pub use queue::TunQueue;
use super::TunOptions;
pub struct TunInterface {
handle: sys::WINTUN_ADAPTER_HANDLE,
name: String,
@ -40,10 +39,7 @@ impl TunInterface {
if handle.is_null() {
unsafe { GetLastError() }.ok()?
}
TunInterface {
handle,
name: name_owned,
}
TunInterface { handle, name: name_owned }
}
pub fn name(&self) -> String {

View file

@ -1,6 +1,6 @@
use std::{io::Error, net::Ipv4Addr};
use fehler::throws;
use std::io::Error;
use std::net::Ipv4Addr;
use tun::TunInterface;
#[test]

View file

@ -1,7 +1,6 @@
use fehler::throws;
use std::io::Error;
use std::{io::Error, net::Ipv4Addr};
use std::net::Ipv4Addr;
use fehler::throws;
use tun::TunInterface;
#[throws]
@ -9,10 +8,10 @@ use tun::TunInterface;
#[ignore = "requires interactivity"]
#[cfg(not(target_os = "windows"))]
fn tst_read() {
// This test is interactive, you need to send a packet to any server through 192.168.1.10
// EG. `sudo route add 8.8.8.8 192.168.1.10`,
// This test is interactive, you need to send a packet to any server through
// 192.168.1.10 EG. `sudo route add 8.8.8.8 192.168.1.10`,
//`dig @8.8.8.8 hackclub.com`
let mut tun = TunInterface::new()?;
let tun = TunInterface::new()?;
println!("tun name: {:?}", tun.name()?);
tun.set_ipv4_addr(Ipv4Addr::from([192, 168, 1, 10]))?;
println!("tun ip: {:?}", tun.ipv4_addr()?);

View file

@ -4,7 +4,7 @@ use std::net::Ipv4Addr;
#[cfg(all(feature = "tokio", not(target_os = "windows")))]
async fn test_create() {
let tun = tun::TunInterface::new().unwrap();
let async_tun = tun::tokio::TunInterface::new(tun).unwrap();
let _ = tun::tokio::TunInterface::new(tun).unwrap();
}
#[tokio::test]
@ -17,6 +17,6 @@ async fn test_write() {
let async_tun = tun::tokio::TunInterface::new(tun).unwrap();
let mut buf = [0u8; 1500];
buf[0] = 6 << 4;
let bytes_written = async_tun.write(&buf).await.unwrap();
let bytes_written = async_tun.send(&buf).await.unwrap();
assert!(bytes_written > 0);
}