From 28af9003d06bfe54e704ec07ca1d0823251510c7 Mon Sep 17 00:00:00 2001 From: Conrad Kramer Date: Sat, 16 Sep 2023 10:45:53 -0700 Subject: [PATCH] merge boringtun into burrow --- .rustfmt.toml | 12 + .vscode/settings.json | 3 + Apple/Burrow.xcodeproj/project.pbxproj | 1 - .../xcshareddata/swiftpm/Package.resolved | 77 -- Cargo.lock | 325 ++++++- Cargo.toml | 1 + burrow/Cargo.toml | 20 +- burrow/src/daemon/instance.rs | 11 +- burrow/src/daemon/mod.rs | 45 +- burrow/src/daemon/net/mod.rs | 3 +- burrow/src/daemon/net/systemd.rs | 3 - burrow/src/daemon/net/unix.rs | 16 +- burrow/src/lib.rs | 2 +- burrow/src/main.rs | 24 +- burrow/src/wireguard/iface.rs | 132 +++ burrow/src/wireguard/mod.rs | 22 + burrow/src/wireguard/noise/errors.rs | 23 + burrow/src/wireguard/noise/handshake.rs | 901 ++++++++++++++++++ burrow/src/wireguard/noise/mod.rs | 609 ++++++++++++ burrow/src/wireguard/noise/rate_limiter.rs | 209 ++++ burrow/src/wireguard/noise/session.rs | 279 ++++++ burrow/src/wireguard/noise/timers.rs | 333 +++++++ burrow/src/wireguard/pcb.rs | 88 ++ burrow/src/wireguard/peer.rs | 23 + tun/build.rs | 5 +- tun/src/lib.rs | 6 +- tun/src/options.rs | 19 +- tun/src/tokio/mod.rs | 5 +- tun/src/unix/apple/kern_control.rs | 12 +- tun/src/unix/apple/mod.rs | 26 +- tun/src/unix/apple/sys.rs | 13 +- tun/src/unix/linux/mod.rs | 32 +- tun/src/unix/linux/sys.rs | 7 +- tun/src/unix/mod.rs | 8 +- tun/src/unix/queue.rs | 13 +- tun/src/windows/mod.rs | 21 +- tun/tests/configure.rs | 8 +- tun/tests/packets.rs | 9 +- tun/tests/tokio.rs | 4 +- 39 files changed, 3122 insertions(+), 228 deletions(-) create mode 100644 .rustfmt.toml delete mode 100644 Apple/Burrow.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved create mode 100755 burrow/src/wireguard/iface.rs create mode 100755 burrow/src/wireguard/mod.rs create mode 100755 burrow/src/wireguard/noise/errors.rs create mode 100755 burrow/src/wireguard/noise/handshake.rs create mode 100755 burrow/src/wireguard/noise/mod.rs create mode 100755 burrow/src/wireguard/noise/rate_limiter.rs create mode 100755 burrow/src/wireguard/noise/session.rs create mode 100755 burrow/src/wireguard/noise/timers.rs create mode 100755 burrow/src/wireguard/pcb.rs create mode 100755 burrow/src/wireguard/peer.rs diff --git a/.rustfmt.toml b/.rustfmt.toml new file mode 100644 index 0000000..2a12e19 --- /dev/null +++ b/.rustfmt.toml @@ -0,0 +1,12 @@ +condense_wildcard_suffixes = true +format_macro_matchers = true +imports_layout = "HorizontalVertical" +imports_granularity = "Crate" +newline_style = "Unix" +overflow_delimited_expr = true +reorder_impl_items = true +group_imports = "StdExternalCrate" +trailing_semicolon = false +use_field_init_shorthand = true +use_try_shorthand = true +struct_lit_width = 30 diff --git a/.vscode/settings.json b/.vscode/settings.json index 4718093..5fbfc5c 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -8,6 +8,9 @@ "editor.acceptSuggestionOnEnter": "on", "rust-analyzer.restartServerOnConfigChange": true, "rust-analyzer.cargo.features": "all", + "rust-analyzer.rustfmt.extraArgs": [ + "+nightly" + ], "[rust]": { "editor.defaultFormatter": "rust-lang.rust-analyzer", } diff --git a/Apple/Burrow.xcodeproj/project.pbxproj b/Apple/Burrow.xcodeproj/project.pbxproj index f9c7454..8cdc60b 100644 --- a/Apple/Burrow.xcodeproj/project.pbxproj +++ b/Apple/Burrow.xcodeproj/project.pbxproj @@ -245,7 +245,6 @@ ); mainGroup = D05B9F6929E39EEC008CB1F9; packageReferences = ( - D0BCC6102A0B327700AD070D /* XCRemoteSwiftPackageReference "SwiftLint" */, ); productRefGroup = D05B9F7329E39EEC008CB1F9 /* Products */; projectDirPath = ""; diff --git a/Apple/Burrow.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved b/Apple/Burrow.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved deleted file mode 100644 index 233bbf9..0000000 --- a/Apple/Burrow.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved +++ /dev/null @@ -1,77 +0,0 @@ -{ - "pins" : [ - { - "identity" : "collectionconcurrencykit", - "kind" : "remoteSourceControl", - "location" : "https://github.com/JohnSundell/CollectionConcurrencyKit.git", - "state" : { - "revision" : "b4f23e24b5a1bff301efc5e70871083ca029ff95", - "version" : "0.2.0" - } - }, - { - "identity" : "sourcekitten", - "kind" : "remoteSourceControl", - "location" : "https://github.com/jpsim/SourceKitten.git", - "state" : { - "revision" : "b6dc09ee51dfb0c66e042d2328c017483a1a5d56", - "version" : "0.34.1" - } - }, - { - "identity" : "swift-argument-parser", - "kind" : "remoteSourceControl", - "location" : "https://github.com/apple/swift-argument-parser.git", - "state" : { - "revision" : "fee6933f37fde9a5e12a1e4aeaa93fe60116ff2a", - "version" : "1.2.2" - } - }, - { - "identity" : "swift-syntax", - "kind" : "remoteSourceControl", - "location" : "https://github.com/apple/swift-syntax.git", - "state" : { - "revision" : "013a48e2312e57b7b355db25bd3ea75282ebf274", - "version" : "0.50900.0-swift-DEVELOPMENT-SNAPSHOT-2023-02-06-a" - } - }, - { - "identity" : "swiftlint", - "kind" : "remoteSourceControl", - "location" : "https://github.com/realm/SwiftLint.git", - "state" : { - "revision" : "eb85125a5f293de3d3248af259980c98bc2b1faa", - "version" : "0.51.0" - } - }, - { - "identity" : "swiftytexttable", - "kind" : "remoteSourceControl", - "location" : "https://github.com/scottrhoyt/SwiftyTextTable.git", - "state" : { - "revision" : "c6df6cf533d120716bff38f8ff9885e1ce2a4ac3", - "version" : "0.9.0" - } - }, - { - "identity" : "swxmlhash", - "kind" : "remoteSourceControl", - "location" : "https://github.com/drmohundro/SWXMLHash.git", - "state" : { - "revision" : "4d0f62f561458cbe1f732171e625f03195151b60", - "version" : "7.0.1" - } - }, - { - "identity" : "yams", - "kind" : "remoteSourceControl", - "location" : "https://github.com/jpsim/Yams.git", - "state" : { - "revision" : "f47ba4838c30dbd59998a4e4c87ab620ff959e8a", - "version" : "5.0.5" - } - } - ], - "version" : 2 -} diff --git a/Cargo.lock b/Cargo.lock index f7cf03a..ff7c601 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,16 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +[[package]] +name = "aead" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d122413f284cf2d62fb1b7db97e02edb8cda96d769b16e443a4f6195e35662b0" +dependencies = [ + "crypto-common", + "generic-array", +] + [[package]] name = "aes" version = "0.8.3" @@ -92,6 +102,23 @@ version = "1.0.71" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c7d0618f0e0b7e8ff11427422b64564d5fb0be1940354bfe2e0529b18a9d9b8" +[[package]] +name = "arrayvec" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711" + +[[package]] +name = "async-trait" +version = "0.1.74" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a66537f1bb974b254c98ed142ff995236e81b9d0fe4db0575f46612cb15eb0f9" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.32", +] + [[package]] name = "autocfg" version = "1.1.0" @@ -115,9 +142,9 @@ dependencies = [ [[package]] name = "base64" -version = "0.21.2" +version = "0.21.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "604178f6c5c21f02dc555784810edfb88d34ac2c73b2eae109655649ee73ce3d" +checksum = "9ba43ea6f343b788c8764558649e08df62f86c6ef251fdaeb1ffd010a9ae50a2" [[package]] name = "base64ct" @@ -166,7 +193,7 @@ dependencies = [ "regex", "rustc-hash", "shlex", - "syn 2.0.22", + "syn 2.0.32", "which", ] @@ -182,6 +209,15 @@ version = "2.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "630be753d4e58660abd17930c71b647fe46c27ea6b63cc59e1e3851406972e42" +[[package]] +name = "blake2" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46502ad458c9a52b69d4d4d32775c788b7a1b85e8bc9d482d92250fc0e3f8efe" +dependencies = [ + "digest", +] + [[package]] name = "block-buffer" version = "0.10.4" @@ -201,13 +237,28 @@ checksum = "a3e2c3daef883ecc1b5d58c15adae93470a91d425f3532ba1695849656af3fc1" name = "burrow" version = "0.1.0" dependencies = [ + "aead", "anyhow", + "async-trait", + "base64", + "blake2", "caps", + "chacha20poly1305", "clap", "env_logger", + "etherparse", + "fehler", + "hmac", + "ip_network", + "ip_network_table", + "ipnet", "libsystemd", "log", "nix", + "parking_lot", + "rand", + "rand_core", + "ring", "serde", "serde_json", "tokio", @@ -217,6 +268,7 @@ dependencies = [ "tracing-oslog", "tracing-subscriber", "tun", + "x25519-dalek", ] [[package]] @@ -286,6 +338,30 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "chacha20" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3613f74bd2eac03dad61bd53dbe620703d4371614fe0bc3b9f04dd36fe4e818" +dependencies = [ + "cfg-if", + "cipher", + "cpufeatures", +] + +[[package]] +name = "chacha20poly1305" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10cd79432192d1c0f4e1a0fef9527696cc039165d729fb41b3f4f4f354c2dc35" +dependencies = [ + "aead", + "chacha20", + "cipher", + "poly1305", + "zeroize", +] + [[package]] name = "cipher" version = "0.4.4" @@ -294,6 +370,7 @@ checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad" dependencies = [ "crypto-common", "inout", + "zeroize", ] [[package]] @@ -339,7 +416,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.22", + "syn 2.0.32", ] [[package]] @@ -410,9 +487,37 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" dependencies = [ "generic-array", + "rand_core", "typenum", ] +[[package]] +name = "curve25519-dalek" +version = "4.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "622178105f911d937a42cdb140730ba4a3ed2becd8ae6ce39c7d28b5d75d4588" +dependencies = [ + "cfg-if", + "cpufeatures", + "curve25519-dalek-derive", + "fiat-crypto", + "platforms", + "rustc_version", + "subtle", + "zeroize", +] + +[[package]] +name = "curve25519-dalek-derive" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83fdaf97f4804dcebfa5862639bc9ce4121e82140bec2a987ac5140294865b5b" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.32", +] + [[package]] name = "digest" version = "0.10.7" @@ -473,6 +578,15 @@ dependencies = [ "libc", ] +[[package]] +name = "etherparse" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bcb08c4aab4e2985045305551e67126b43f1b6b136bc4e1cd87fb0327877a611" +dependencies = [ + "arrayvec", +] + [[package]] name = "fastrand" version = "1.9.0" @@ -502,6 +616,12 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "fiat-crypto" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0870c84016d4b481be5c9f323c24f65e31e901ae618f0e80f4308fb00de1d2d" + [[package]] name = "flate2" version = "1.0.26" @@ -598,7 +718,7 @@ checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72" dependencies = [ "proc-macro2", "quote", - "syn 2.0.22", + "syn 2.0.32", ] [[package]] @@ -641,6 +761,17 @@ dependencies = [ "version_check", ] +[[package]] +name = "getrandom" +version = "0.2.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4136b2a15dd319360be1c07d9933517ccf0be8f16bf62a3bee4f0d618df427" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + [[package]] name = "gimli" version = "0.27.3" @@ -831,11 +962,36 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "ip_network" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa2f047c0a98b2f299aa5d6d7088443570faae494e9ae1305e48be000c9e0eb1" + +[[package]] +name = "ip_network_table" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4099b7cfc5c5e2fe8c5edf3f6f7adf7a714c9cc697534f63a5a5da30397cb2c0" +dependencies = [ + "ip_network", + "ip_network_table-deps-treebitmap", +] + +[[package]] +name = "ip_network_table-deps-treebitmap" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e537132deb99c0eb4b752f0346b6a836200eaaa3516dd7e5514b63930a09e5d" + [[package]] name = "ipnet" version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "28b29a3cd74f0f4598934efe3aeba42bae0eb4680554128851ebbecb02af14e6" +dependencies = [ + "serde", +] [[package]] name = "is-terminal" @@ -981,7 +1137,7 @@ checksum = "4901771e1d44ddb37964565c654a3223ba41a594d02b8da471cc4464912b5cfa" dependencies = [ "proc-macro2", "quote", - "syn 2.0.22", + "syn 2.0.32", ] [[package]] @@ -1083,6 +1239,12 @@ version = "1.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" +[[package]] +name = "opaque-debug" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5" + [[package]] name = "openssl" version = "0.10.55" @@ -1106,7 +1268,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.22", + "syn 2.0.32", ] [[package]] @@ -1209,6 +1371,29 @@ version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26072860ba924cbfa98ea39c8c19b4dd6a4a25423dbdf219c1eca91aa0cf6964" +[[package]] +name = "platforms" +version = "3.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4503fa043bf02cee09a9582e9554b4c6403b2ef55e4612e96561d294419429f8" + +[[package]] +name = "poly1305" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8159bd90725d2df49889a078b54f4f79e87f1f8a8444194cdca81d38f5393abf" +dependencies = [ + "cpufeatures", + "opaque-debug", + "universal-hash", +] + +[[package]] +name = "ppv-lite86" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" + [[package]] name = "prettyplease" version = "0.2.9" @@ -1216,7 +1401,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9825a04601d60621feed79c4e6b56d65db77cdca55cef43b46b0de1096d1c282" dependencies = [ "proc-macro2", - "syn 2.0.22", + "syn 2.0.32", ] [[package]] @@ -1237,11 +1422,35 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + [[package]] name = "rand_core" version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] [[package]] name = "redox_syscall" @@ -1306,6 +1515,21 @@ dependencies = [ "winreg", ] +[[package]] +name = "ring" +version = "0.16.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3053cf52e236a3ed746dfc745aa9cacf1b791d846bdaf412f60a8d7d6e17c8fc" +dependencies = [ + "cc", + "libc", + "once_cell", + "spin", + "untrusted", + "web-sys", + "winapi", +] + [[package]] name = "rustc-demangle" version = "0.1.23" @@ -1318,6 +1542,15 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" +[[package]] +name = "rustc_version" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" +dependencies = [ + "semver", +] + [[package]] name = "rustix" version = "0.37.21" @@ -1389,6 +1622,12 @@ dependencies = [ "libc", ] +[[package]] +name = "semver" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0293b4b29daaf487284529cc2f5675b8e57c61f70167ba415a463651fd6a918" + [[package]] name = "serde" version = "1.0.164" @@ -1406,7 +1645,7 @@ checksum = "d9735b638ccc51c28bf6914d90a2e9725b377144fc612c49a611fddd1b631d68" dependencies = [ "proc-macro2", "quote", - "syn 2.0.22", + "syn 2.0.32", ] [[package]] @@ -1505,6 +1744,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "spin" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" + [[package]] name = "ssri" version = "9.0.0" @@ -1552,9 +1797,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.22" +version = "2.0.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2efbeae7acf4eabd6bcdcbd11c92f45231ddda7539edc7806bd1a04a03b24616" +checksum = "239814284fd6f1a4ffe4ca893952cdd93c224b6a1571c9a9eadd670295c0c9e2" dependencies = [ "proc-macro2", "quote", @@ -1601,7 +1846,7 @@ checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f" dependencies = [ "proc-macro2", "quote", - "syn 2.0.22", + "syn 2.0.32", ] [[package]] @@ -1670,7 +1915,7 @@ checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.22", + "syn 2.0.32", ] [[package]] @@ -1723,7 +1968,7 @@ checksum = "5f4f31f56159e98206da9efd823404b79b6ef3143b4a7ab76e67b1751b25a4ab" dependencies = [ "proc-macro2", "quote", - "syn 2.0.22", + "syn 2.0.32", ] [[package]] @@ -1852,6 +2097,22 @@ version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c0edd1e5b14653f783770bce4a4dabb4a5108a5370a5f5d8cfe8710c361f6c8b" +[[package]] +name = "universal-hash" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc1de2c688dc15305988b563c3854064043356019f97a4b46276fe734c4f07ea" +dependencies = [ + "crypto-common", + "subtle", +] + +[[package]] +name = "untrusted" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" + [[package]] name = "url" version = "2.4.0" @@ -1932,7 +2193,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.22", + "syn 2.0.32", "wasm-bindgen-shared", ] @@ -1966,7 +2227,7 @@ checksum = "54681b18a46765f095758388f2d0cf16eb8d4169b639ab575a8f5693af210c7b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.22", + "syn 2.0.32", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -2176,12 +2437,44 @@ dependencies = [ "winapi", ] +[[package]] +name = "x25519-dalek" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb66477291e7e8d2b0ff1bcb900bf29489a9692816d79874bea351e7a8b6de96" +dependencies = [ + "curve25519-dalek", + "rand_core", + "serde", + "zeroize", +] + [[package]] name = "xxhash-rust" version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "735a71d46c4d68d71d4b24d03fdc2b98e38cea81730595801db779c04fe80d70" +[[package]] +name = "zeroize" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a0956f1ba7c7909bfb66c2e9e4124ab6f6482560f6628b5aaeba39207c9aad9" +dependencies = [ + "zeroize_derive", +] + +[[package]] +name = "zeroize_derive" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.32", +] + [[package]] name = "zip" version = "0.6.6" diff --git a/Cargo.toml b/Cargo.toml index fcb83f5..3452869 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,2 +1,3 @@ [workspace] members = ["burrow", "tun"] +resolver = "2" diff --git a/burrow/Cargo.toml b/burrow/Cargo.toml index c9d0e67..06519de 100644 --- a/burrow/Cargo.toml +++ b/burrow/Cargo.toml @@ -9,7 +9,7 @@ crate-type = ["lib", "staticlib"] [dependencies] anyhow = "1.0" tokio = { version = "1.21", features = ["rt", "macros", "sync", "io-util"] } -tun = { version = "0.1", path = "../tun", features = ["serde"] } +tun = { version = "0.1", path = "../tun", features = ["serde", "tokio"] } clap = { version = "4.3.2", features = ["derive"] } tracing = "0.1" tracing-log = "0.1" @@ -20,6 +20,21 @@ env_logger = "0.10" log = "0.4" serde = { version = "1", features = ["derive"] } serde_json = "1" +blake2 = "0.10.6" +chacha20poly1305 = "0.10.1" +rand = "0.8.5" +rand_core = "0.6.4" +aead = "0.5.2" +x25519-dalek = { version = "2.0.0", features = ["reusable_secrets", "static_secrets"] } +ring = "0.16.20" +parking_lot = "0.12.1" +hmac = "0.12" +ipnet = { version = "2.8.0", features = ["serde"] } +base64 = "0.21.4" +fehler = "1.0.0" +ip_network_table = "0.2.0" +ip_network = "0.4.0" +async-trait = "0.1.74" [target.'cfg(target_os = "linux")'.dependencies] caps = "0.5.5" @@ -27,3 +42,6 @@ libsystemd = "0.6" [target.'cfg(target_vendor = "apple")'.dependencies] nix = { version = "0.26.2" } + +[dev-dependencies] +etherparse = "0.12" diff --git a/burrow/src/daemon/instance.rs b/burrow/src/daemon/instance.rs index d1849d0..efcd2bf 100644 --- a/burrow/src/daemon/instance.rs +++ b/burrow/src/daemon/instance.rs @@ -1,4 +1,8 @@ -use super::*; +use anyhow::Result; +use tokio::sync::mpsc; +use tun::TunInterface; + +use super::DaemonCommand; pub struct DaemonInstance { rx: mpsc::Receiver, @@ -7,10 +11,7 @@ pub struct DaemonInstance { impl DaemonInstance { pub fn new(rx: mpsc::Receiver) -> Self { - Self { - rx, - tun_interface: None, - } + Self { rx, tun_interface: None } } pub async fn run(&mut self) -> Result<()> { diff --git a/burrow/src/daemon/mod.rs b/burrow/src/daemon/mod.rs index 5fcf8ee..2fa09be 100644 --- a/burrow/src/daemon/mod.rs +++ b/burrow/src/daemon/mod.rs @@ -1,19 +1,52 @@ -use super::*; +use std::net::SocketAddr; + use tokio::sync::mpsc; mod command; mod instance; mod net; -use instance::DaemonInstance; -use net::listen; - +use anyhow::Error; +use base64::{engine::general_purpose, Engine as _}; +use burrow::wireguard::{Interface, Peer, PublicKey, StaticSecret}; pub use command::{DaemonCommand, DaemonStartOptions}; +use fehler::throws; +use instance::DaemonInstance; pub use net::DaemonClient; -pub async fn daemon_main() -> Result<()> { +#[throws] +fn parse_secret_key(string: &str) -> StaticSecret { + let value = general_purpose::STANDARD.decode(string)?; + let mut key = [0u8; 32]; + key.copy_from_slice(&value[..]); + StaticSecret::from(key) +} + +#[throws] +fn parse_public_key(string: &str) -> PublicKey { + let value = general_purpose::STANDARD.decode(string)?; + let mut key = [0u8; 32]; + key.copy_from_slice(&value[..]); + PublicKey::from(key) +} + +pub async fn daemon_main() -> anyhow::Result<()> { let (tx, rx) = mpsc::channel(2); let mut inst = DaemonInstance::new(rx); + // tokio::try_join!(inst.run(), listen(tx)).map(|_| ()) - tokio::try_join!(inst.run(), listen(tx)).map(|_| ()) + let tun = tun::tokio::TunInterface::new(tun::TunInterface::new()?)?; + + let private_key = parse_secret_key("sIxpokQPnWctJKNaQ3DRdcQbL2S5OMbUrvr4bbsvTHw=")?; + let public_key = parse_public_key("EKZXvHlSDeqAjfC/m9aQR0oXfQ6Idgffa9L0DH5yaCo=")?; + let endpoint = "146.70.173.66:51820".parse::()?; + let iface = Interface::new(tun, vec![Peer { + endpoint, + private_key, + public_key, + allowed_ips: vec![], + }])?; + + iface.run().await; + Ok(()) } diff --git a/burrow/src/daemon/net/mod.rs b/burrow/src/daemon/net/mod.rs index d8cc5fa..5eb7c34 100644 --- a/burrow/src/daemon/net/mod.rs +++ b/burrow/src/daemon/net/mod.rs @@ -1,6 +1,7 @@ -use super::*; use serde::{Deserialize, Serialize}; +use super::DaemonCommand; + #[cfg(target_family = "unix")] mod unix; #[cfg(all(target_family = "unix", not(target_os = "linux")))] diff --git a/burrow/src/daemon/net/systemd.rs b/burrow/src/daemon/net/systemd.rs index f67888e..5c89a4e 100644 --- a/burrow/src/daemon/net/systemd.rs +++ b/burrow/src/daemon/net/systemd.rs @@ -1,6 +1,3 @@ -use super::*; -use std::os::fd::IntoRawFd; - pub async fn listen(cmd_tx: mpsc::Sender) -> Result<()> { if !libsystemd::daemon::booted() || listen_with_systemd(cmd_tx.clone()).await.is_err() { unix::listen(cmd_tx).await?; diff --git a/burrow/src/daemon/net/unix.rs b/burrow/src/daemon/net/unix.rs index a2837a3..c5254d9 100644 --- a/burrow/src/daemon/net/unix.rs +++ b/burrow/src/daemon/net/unix.rs @@ -1,14 +1,21 @@ -use super::*; use std::{ - os::fd::{FromRawFd, RawFd}, - os::unix::net::UnixListener as StdUnixListener, + os::{ + fd::{FromRawFd, RawFd}, + unix::net::UnixListener as StdUnixListener, + }, path::Path, }; + +use anyhow::Result; use tokio::{ io::{AsyncBufReadExt, AsyncWriteExt, BufReader}, net::{UnixListener, UnixStream}, + sync::mpsc, }; +use super::{DaemonRequest, DaemonResponse}; +use crate::daemon::DaemonCommand; + const UNIX_SOCKET_PATH: &str = "/run/burrow.sock"; pub async fn listen(cmd_tx: mpsc::Sender) -> Result<()> { @@ -40,7 +47,8 @@ pub(crate) async fn listen_with_optional_fd( let cmd_tx = cmd_tx.clone(); // I'm pretty sure we won't need to manually join / shut this down, - // `lines` will return Err during dropping, and this task should exit gracefully. + // `lines` will return Err during dropping, and this task should exit + // gracefully. tokio::task::spawn(async { let cmd_tx = cmd_tx; let mut stream = stream; diff --git a/burrow/src/lib.rs b/burrow/src/lib.rs index 1032e97..3c1456d 100644 --- a/burrow/src/lib.rs +++ b/burrow/src/lib.rs @@ -1,5 +1,5 @@ -#![deny(missing_debug_implementations)] pub mod ensureroot; +pub mod wireguard; #[cfg(any(target_os = "linux", target_vendor = "apple"))] use std::{ diff --git a/burrow/src/main.rs b/burrow/src/main.rs index 1f70b1c..483b459 100644 --- a/burrow/src/main.rs +++ b/burrow/src/main.rs @@ -1,17 +1,15 @@ -use anyhow::Context; use std::mem; #[cfg(any(target_os = "linux", target_vendor = "apple"))] use std::os::fd::FromRawFd; +use anyhow::{Context, Result}; +#[cfg(any(target_os = "linux", target_vendor = "apple"))] +use burrow::retrieve; use clap::{Args, Parser, Subcommand}; use tracing::instrument; - use tracing_log::LogTracer; use tracing_oslog::OsLogger; use tracing_subscriber::{prelude::*, FmtSubscriber}; -use tokio::io::Result; -#[cfg(any(target_os = "linux", target_vendor = "apple"))] -use burrow::retrieve; use tun::TunInterface; mod daemon; @@ -66,13 +64,17 @@ async fn try_start() -> Result<()> { #[cfg(any(target_os = "linux", target_vendor = "apple"))] #[instrument] async fn try_retrieve() -> Result<()> { - LogTracer::init().context("Failed to initialize LogTracer").unwrap(); + LogTracer::init() + .context("Failed to initialize LogTracer") + .unwrap(); if cfg!(target_os = "linux") || cfg!(target_vendor = "apple") { let maybe_layer = system_log().unwrap(); if let Some(layer) = maybe_layer { let logger = layer.with_subscriber(FmtSubscriber::new()); - tracing::subscriber::set_global_default(logger).context("Failed to set the global tracing subscriber").unwrap(); + tracing::subscriber::set_global_default(logger) + .context("Failed to set the global tracing subscriber") + .unwrap(); } } @@ -128,18 +130,18 @@ async fn main() -> Result<()> { } #[cfg(target_os = "linux")] -fn system_log() -> anyhow::Result> { +fn system_log() -> Result> { let maybe_journald = tracing_journald::layer(); match maybe_journald { Err(e) if e.kind() == std::io::ErrorKind::NotFound => { tracing::trace!("journald not found"); Ok(None) - }, - _ => Ok(Some(maybe_journald?)) + } + _ => Ok(Some(maybe_journald?)), } } #[cfg(target_vendor = "apple")] -fn system_log() -> anyhow::Result> { +fn system_log() -> Result> { Ok(Some(OsLogger::new("com.hackclub.burrow", "default"))) } diff --git a/burrow/src/wireguard/iface.rs b/burrow/src/wireguard/iface.rs new file mode 100755 index 0000000..ede3424 --- /dev/null +++ b/burrow/src/wireguard/iface.rs @@ -0,0 +1,132 @@ +use std::{net::IpAddr, rc::Rc}; + +use anyhow::Error; +use async_trait::async_trait; +use fehler::throws; +use ip_network_table::IpNetworkTable; +use tokio::{ + join, + sync::Mutex, + task::{self, JoinHandle}, +}; +use tun::tokio::TunInterface; + +use super::{noise::Tunnel, pcb, Peer, PeerPcb}; + +#[async_trait] +pub trait PacketInterface { + async fn recv(&mut self, buf: &mut [u8]) -> Result; + async fn send(&mut self, buf: &[u8]) -> Result; +} + +#[async_trait] +impl PacketInterface for tun::tokio::TunInterface { + async fn recv(&mut self, buf: &mut [u8]) -> Result { + self.recv(buf).await + } + + async fn send(&mut self, buf: &[u8]) -> Result { + self.send(buf).await + } +} + +struct IndexedPcbs { + pcbs: Vec, + allowed_ips: IpNetworkTable, +} + +impl IndexedPcbs { + pub fn new() -> Self { + Self { + pcbs: vec![], + allowed_ips: IpNetworkTable::new(), + } + } + + pub fn insert(&mut self, pcb: PeerPcb) { + let idx: usize = self.pcbs.len(); + for allowed_ip in pcb.allowed_ips.iter() { + self.allowed_ips.insert(allowed_ip.clone(), idx); + } + self.pcbs.insert(idx, pcb); + } + + pub fn find(&mut self, addr: IpAddr) -> Option { + let (_, &idx) = self.allowed_ips.longest_match(addr)?; + Some(idx) + } + + pub fn connect(&mut self, idx: usize, handle: JoinHandle<()>) { + self.pcbs[idx].handle = Some(handle); + } +} + +impl FromIterator for IndexedPcbs { + fn from_iter>(iter: I) -> Self { + iter.into_iter().fold(Self::new(), |mut acc, pcb| { + acc.insert(pcb); + acc + }) + } +} + +pub struct Interface { + tun: Rc>, + pcbs: Rc>, +} + +impl Interface { + #[throws] + pub fn new>(tun: TunInterface, peers: I) -> Self { + let pcbs: IndexedPcbs = peers + .into_iter() + .map(|peer| PeerPcb::new(peer)) + .collect::>()?; + + let tun = Rc::new(Mutex::new(tun)); + let pcbs = Rc::new(Mutex::new(pcbs)); + Self { tun, pcbs } + } + + pub async fn run(self) { + let pcbs = self.pcbs; + let tun = self.tun; + + let outgoing = async move { + loop { + let mut buf = [0u8; 3000]; + + let mut tun = tun.lock().await; + let src = match tun.recv(&mut buf[..]).await { + Ok(len) => &buf[..len], + Err(e) => { + log::error!("failed reading from interface: {}", e); + continue + } + }; + + let mut pcbs = pcbs.lock().await; + + let dst_addr = match Tunnel::dst_address(src) { + Some(addr) => addr, + None => continue, + }; + + let Some(idx) = pcbs.find(dst_addr) else { + continue + }; + match pcbs.pcbs[idx].send(src).await { + Ok(..) => {} + Err(e) => log::error!("failed to send packet {}", e), + } + } + }; + + task::LocalSet::new() + .run_until(async move { + let outgoing = task::spawn_local(outgoing); + join!(outgoing); + }) + .await; + } +} diff --git a/burrow/src/wireguard/mod.rs b/burrow/src/wireguard/mod.rs new file mode 100755 index 0000000..4383ff8 --- /dev/null +++ b/burrow/src/wireguard/mod.rs @@ -0,0 +1,22 @@ +mod iface; +mod noise; +mod pcb; +mod peer; + +pub use iface::Interface; +pub use pcb::PeerPcb; +pub use peer::Peer; +pub use x25519_dalek::{PublicKey, StaticSecret}; + +const WIREGUARD_CONFIG: &str = r#" +[Interface] +# Device: Gentle Tomcat +PrivateKey = sIxpokQPnWctJKNaQ3DRdcQbL2S5OMbUrvr4bbsvTHw= +Address = 10.68.136.199/32,fc00:bbbb:bbbb:bb01::5:88c6/128 +DNS = 10.64.0.1 + +[Peer] +PublicKey = EKZXvHlSDeqAjfC/m9aQR0oXfQ6Idgffa9L0DH5yaCo= +AllowedIPs = 0.0.0.0/0,::0/0 +Endpoint = 146.70.173.66:51820 +"#; diff --git a/burrow/src/wireguard/noise/errors.rs b/burrow/src/wireguard/noise/errors.rs new file mode 100755 index 0000000..10513ae --- /dev/null +++ b/burrow/src/wireguard/noise/errors.rs @@ -0,0 +1,23 @@ +// Copyright (c) 2019 Cloudflare, Inc. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +#[derive(Debug)] +pub enum WireGuardError { + DestinationBufferTooSmall, + IncorrectPacketLength, + UnexpectedPacket, + WrongPacketType, + WrongIndex, + WrongKey, + InvalidTai64nTimestamp, + WrongTai64nTimestamp, + InvalidMac, + InvalidAeadTag, + InvalidCounter, + DuplicateCounter, + InvalidPacket, + NoCurrentSession, + LockFailed, + ConnectionExpired, + UnderLoad, +} diff --git a/burrow/src/wireguard/noise/handshake.rs b/burrow/src/wireguard/noise/handshake.rs new file mode 100755 index 0000000..c672109 --- /dev/null +++ b/burrow/src/wireguard/noise/handshake.rs @@ -0,0 +1,901 @@ +// Copyright (c) 2019 Cloudflare, Inc. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +use std::{ + convert::TryInto, + time::{Duration, Instant, SystemTime}, +}; + +use aead::{Aead, Payload}; +use blake2::{ + digest::{FixedOutput, KeyInit}, + Blake2s256, + Blake2sMac, + Digest, +}; +use chacha20poly1305::XChaCha20Poly1305; +use rand_core::OsRng; +use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305}; + +use super::{ + errors::WireGuardError, + session::Session, + x25519, + HandshakeInit, + HandshakeResponse, + PacketCookieReply, +}; + +pub(crate) const LABEL_MAC1: &[u8; 8] = b"mac1----"; +pub(crate) const LABEL_COOKIE: &[u8; 8] = b"cookie--"; +const KEY_LEN: usize = 32; +const TIMESTAMP_LEN: usize = 12; + +// initiator.chaining_key = HASH(CONSTRUCTION) +const INITIAL_CHAIN_KEY: [u8; KEY_LEN] = [ + 96, 226, 109, 174, 243, 39, 239, 192, 46, 195, 53, 226, 160, 37, 210, 208, 22, 235, 66, 6, 248, + 114, 119, 245, 45, 56, 209, 152, 139, 120, 205, 54, +]; + +// initiator.chaining_hash = HASH(initiator.chaining_key || IDENTIFIER) +const INITIAL_CHAIN_HASH: [u8; KEY_LEN] = [ + 34, 17, 179, 97, 8, 26, 197, 102, 105, 18, 67, 219, 69, 138, 213, 50, 45, 156, 108, 102, 34, + 147, 232, 183, 14, 225, 156, 101, 186, 7, 158, 243, +]; + +#[inline] +pub(crate) fn b2s_hash(data1: &[u8], data2: &[u8]) -> [u8; 32] { + let mut hash = Blake2s256::new(); + hash.update(data1); + hash.update(data2); + hash.finalize().into() +} + +#[inline] +/// RFC 2401 HMAC+Blake2s, not to be confused with *keyed* Blake2s +pub(crate) fn b2s_hmac(key: &[u8], data1: &[u8]) -> [u8; 32] { + use blake2::digest::Update; + type HmacBlake2s = hmac::SimpleHmac; + let mut hmac = HmacBlake2s::new_from_slice(key).unwrap(); + hmac.update(data1); + hmac.finalize_fixed().into() +} + +#[inline] +/// Like b2s_hmac, but chain data1 and data2 together +pub(crate) fn b2s_hmac2(key: &[u8], data1: &[u8], data2: &[u8]) -> [u8; 32] { + use blake2::digest::Update; + type HmacBlake2s = hmac::SimpleHmac; + let mut hmac = HmacBlake2s::new_from_slice(key).unwrap(); + hmac.update(data1); + hmac.update(data2); + hmac.finalize_fixed().into() +} + +#[inline] +pub(crate) fn b2s_keyed_mac_16(key: &[u8], data1: &[u8]) -> [u8; 16] { + let mut hmac = Blake2sMac::new_from_slice(key).unwrap(); + blake2::digest::Update::update(&mut hmac, data1); + hmac.finalize_fixed().into() +} + +#[inline] +pub(crate) fn b2s_keyed_mac_16_2(key: &[u8], data1: &[u8], data2: &[u8]) -> [u8; 16] { + let mut hmac = Blake2sMac::new_from_slice(key).unwrap(); + blake2::digest::Update::update(&mut hmac, data1); + blake2::digest::Update::update(&mut hmac, data2); + hmac.finalize_fixed().into() +} + +pub(crate) fn b2s_mac_24(key: &[u8], data1: &[u8]) -> [u8; 24] { + let mut hmac = Blake2sMac::new_from_slice(key).unwrap(); + blake2::digest::Update::update(&mut hmac, data1); + hmac.finalize_fixed().into() +} + +#[inline] +/// This wrapper involves an extra copy and MAY BE SLOWER +fn aead_chacha20_seal(ciphertext: &mut [u8], key: &[u8], counter: u64, data: &[u8], aad: &[u8]) { + let mut nonce: [u8; 12] = [0; 12]; + nonce[4..12].copy_from_slice(&counter.to_le_bytes()); + + aead_chacha20_seal_inner(ciphertext, key, nonce, data, aad) +} + +#[inline] +fn aead_chacha20_seal_inner( + ciphertext: &mut [u8], + key: &[u8], + nonce: [u8; 12], + data: &[u8], + aad: &[u8], +) { + let key = LessSafeKey::new(UnboundKey::new(&CHACHA20_POLY1305, key).unwrap()); + + ciphertext[..data.len()].copy_from_slice(data); + + let tag = key + .seal_in_place_separate_tag( + Nonce::assume_unique_for_key(nonce), + Aad::from(aad), + &mut ciphertext[..data.len()], + ) + .unwrap(); + + ciphertext[data.len()..].copy_from_slice(tag.as_ref()); +} + +#[inline] +/// This wrapper involves an extra copy and MAY BE SLOWER +fn aead_chacha20_open( + buffer: &mut [u8], + key: &[u8], + counter: u64, + data: &[u8], + aad: &[u8], +) -> Result<(), WireGuardError> { + let mut nonce: [u8; 12] = [0; 12]; + nonce[4..].copy_from_slice(&counter.to_le_bytes()); + + aead_chacha20_open_inner(buffer, key, nonce, data, aad) + .map_err(|_| WireGuardError::InvalidAeadTag)?; + Ok(()) +} + +#[inline] +fn aead_chacha20_open_inner( + buffer: &mut [u8], + key: &[u8], + nonce: [u8; 12], + data: &[u8], + aad: &[u8], +) -> Result<(), ring::error::Unspecified> { + let key = LessSafeKey::new(UnboundKey::new(&CHACHA20_POLY1305, key).unwrap()); + + let mut inner_buffer = data.to_owned(); + + let plaintext = key.open_in_place( + Nonce::assume_unique_for_key(nonce), + Aad::from(aad), + &mut inner_buffer, + )?; + + buffer.copy_from_slice(plaintext); + + Ok(()) +} + +#[derive(Debug)] +/// This struct represents a 12 byte [Tai64N](https://cr.yp.to/libtai/tai64.html) timestamp +struct Tai64N { + secs: u64, + nano: u32, +} + +#[derive(Debug)] +/// This struct computes a [Tai64N](https://cr.yp.to/libtai/tai64.html) timestamp from current system time +struct TimeStamper { + duration_at_start: Duration, + instant_at_start: Instant, +} + +impl TimeStamper { + /// Create a new TimeStamper + pub fn new() -> TimeStamper { + TimeStamper { + duration_at_start: SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .unwrap(), + instant_at_start: Instant::now(), + } + } + + /// Take time reading and generate a 12 byte timestamp + pub fn stamp(&self) -> [u8; 12] { + const TAI64_BASE: u64 = (1u64 << 62) + 37; + let mut ext_stamp = [0u8; 12]; + let stamp = Instant::now().duration_since(self.instant_at_start) + self.duration_at_start; + ext_stamp[0..8].copy_from_slice(&(stamp.as_secs() + TAI64_BASE).to_be_bytes()); + ext_stamp[8..12].copy_from_slice(&stamp.subsec_nanos().to_be_bytes()); + ext_stamp + } +} + +impl Tai64N { + /// A zeroed out timestamp + fn zero() -> Tai64N { + Tai64N { secs: 0, nano: 0 } + } + + /// Parse a timestamp from a 12 byte u8 slice + fn parse(buf: &[u8; 12]) -> Result { + if buf.len() < 12 { + return Err(WireGuardError::InvalidTai64nTimestamp) + } + + let (sec_bytes, nano_bytes) = buf.split_at(std::mem::size_of::()); + let secs = u64::from_be_bytes(sec_bytes.try_into().unwrap()); + let nano = u32::from_be_bytes(nano_bytes.try_into().unwrap()); + + // WireGuard does not actually expect tai64n timestamp, just monotonically + // increasing one if secs < (1u64 << 62) || secs >= (1u64 << 63) { + // return Err(WireGuardError::InvalidTai64nTimestamp); + //}; + // if nano >= 1_000_000_000 { + // return Err(WireGuardError::InvalidTai64nTimestamp); + //} + + Ok(Tai64N { secs, nano }) + } + + /// Check if this timestamp represents a time that is chronologically after + /// the time represented by the other timestamp + pub fn after(&self, other: &Tai64N) -> bool { + (self.secs > other.secs) || ((self.secs == other.secs) && (self.nano > other.nano)) + } +} + +/// Parameters used by the noise protocol +struct NoiseParams { + /// Our static public key + static_public: x25519::PublicKey, + /// Our static private key + static_private: x25519::StaticSecret, + /// Static public key of the other party + peer_static_public: x25519::PublicKey, + /// A shared key = DH(static_private, peer_static_public) + static_shared: x25519::SharedSecret, + /// A pre-computation of HASH("mac1----", peer_static_public) for this peer + sending_mac1_key: [u8; KEY_LEN], + /// An optional preshared key + preshared_key: Option<[u8; KEY_LEN]>, +} + +impl std::fmt::Debug for NoiseParams { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("NoiseParams") + .field("static_public", &self.static_public) + .field("static_private", &"") + .field("peer_static_public", &self.peer_static_public) + .field("static_shared", &"") + .field("sending_mac1_key", &self.sending_mac1_key) + .field("preshared_key", &self.preshared_key) + .finish() + } +} + +struct HandshakeInitSentState { + local_index: u32, + hash: [u8; KEY_LEN], + chaining_key: [u8; KEY_LEN], + ephemeral_private: x25519::ReusableSecret, + time_sent: Instant, +} + +impl std::fmt::Debug for HandshakeInitSentState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("HandshakeInitSentState") + .field("local_index", &self.local_index) + .field("hash", &self.hash) + .field("chaining_key", &self.chaining_key) + .field("ephemeral_private", &"") + .field("time_sent", &self.time_sent) + .finish() + } +} + +#[derive(Debug)] +enum HandshakeState { + /// No handshake in process + None, + /// We initiated the handshake + InitSent(HandshakeInitSentState), + /// Handshake initiated by peer + InitReceived { + hash: [u8; KEY_LEN], + chaining_key: [u8; KEY_LEN], + peer_ephemeral_public: x25519::PublicKey, + peer_index: u32, + }, + /// Handshake was established too long ago (implies no handshake is in + /// progress) + Expired, +} + +#[derive(Debug)] +pub struct Handshake { + params: NoiseParams, + /// Index of the next session + next_index: u32, + /// Allow to have two outgoing handshakes in flight, because sometimes we + /// may receive a delayed response to a handshake with bad networks + previous: HandshakeState, + /// Current handshake state + state: HandshakeState, + cookies: Cookies, + /// The timestamp of the last handshake we received + last_handshake_timestamp: Tai64N, + // TODO: make TimeStamper a singleton + stamper: TimeStamper, + pub(super) last_rtt: Option, +} + +#[derive(Default, Debug)] +struct Cookies { + last_mac1: Option<[u8; 16]>, + index: u32, + write_cookie: Option<[u8; 16]>, +} + +#[derive(Debug)] +pub struct HalfHandshake { + pub peer_index: u32, + pub peer_static_public: [u8; 32], +} + +pub fn parse_handshake_anon( + static_private: &x25519::StaticSecret, + static_public: &x25519::PublicKey, + packet: &HandshakeInit, +) -> Result { + let peer_index = packet.sender_idx; + // initiator.chaining_key = HASH(CONSTRUCTION) + let mut chaining_key = INITIAL_CHAIN_KEY; + // initiator.hash = HASH(HASH(initiator.chaining_key || IDENTIFIER) || + // responder.static_public) + let mut hash = INITIAL_CHAIN_HASH; + hash = b2s_hash(&hash, static_public.as_bytes()); + // msg.unencrypted_ephemeral = DH_PUBKEY(initiator.ephemeral_private) + let peer_ephemeral_public = x25519::PublicKey::from(*packet.unencrypted_ephemeral); + // initiator.hash = HASH(initiator.hash || msg.unencrypted_ephemeral) + hash = b2s_hash(&hash, peer_ephemeral_public.as_bytes()); + // temp = HMAC(initiator.chaining_key, msg.unencrypted_ephemeral) + // initiator.chaining_key = HMAC(temp, 0x1) + chaining_key = b2s_hmac( + &b2s_hmac(&chaining_key, peer_ephemeral_public.as_bytes()), + &[0x01], + ); + // temp = HMAC(initiator.chaining_key, DH(initiator.ephemeral_private, + // responder.static_public)) + let ephemeral_shared = static_private.diffie_hellman(&peer_ephemeral_public); + let temp = b2s_hmac(&chaining_key, &ephemeral_shared.to_bytes()); + // initiator.chaining_key = HMAC(temp, 0x1) + chaining_key = b2s_hmac(&temp, &[0x01]); + // key = HMAC(temp, initiator.chaining_key || 0x2) + let key = b2s_hmac2(&temp, &chaining_key, &[0x02]); + + let mut peer_static_public = [0u8; KEY_LEN]; + // msg.encrypted_static = AEAD(key, 0, initiator.static_public, initiator.hash) + aead_chacha20_open( + &mut peer_static_public, + &key, + 0, + packet.encrypted_static, + &hash, + )?; + + Ok(HalfHandshake { peer_index, peer_static_public }) +} + +impl NoiseParams { + /// New noise params struct from our secret key, peers public key, and + /// optional preshared key + fn new( + static_private: x25519::StaticSecret, + static_public: x25519::PublicKey, + peer_static_public: x25519::PublicKey, + preshared_key: Option<[u8; 32]>, + ) -> Result { + let static_shared = static_private.diffie_hellman(&peer_static_public); + + let initial_sending_mac_key = b2s_hash(LABEL_MAC1, peer_static_public.as_bytes()); + + Ok(NoiseParams { + static_public, + static_private, + peer_static_public, + static_shared, + sending_mac1_key: initial_sending_mac_key, + preshared_key, + }) + } + + /// Set a new private key + fn set_static_private( + &mut self, + static_private: x25519::StaticSecret, + static_public: x25519::PublicKey, + ) -> Result<(), WireGuardError> { + // Check that the public key indeed matches the private key + let check_key = x25519::PublicKey::from(&static_private); + assert_eq!(check_key.as_bytes(), static_public.as_bytes()); + + self.static_private = static_private; + self.static_public = static_public; + + self.static_shared = self.static_private.diffie_hellman(&self.peer_static_public); + Ok(()) + } +} + +impl Handshake { + pub(crate) fn new( + static_private: x25519::StaticSecret, + static_public: x25519::PublicKey, + peer_static_public: x25519::PublicKey, + global_idx: u32, + preshared_key: Option<[u8; 32]>, + ) -> Result { + let params = NoiseParams::new( + static_private, + static_public, + peer_static_public, + preshared_key, + )?; + + Ok(Handshake { + params, + next_index: global_idx, + previous: HandshakeState::None, + state: HandshakeState::None, + last_handshake_timestamp: Tai64N::zero(), + stamper: TimeStamper::new(), + cookies: Default::default(), + last_rtt: None, + }) + } + + pub(crate) fn is_in_progress(&self) -> bool { + !matches!(self.state, HandshakeState::None | HandshakeState::Expired) + } + + pub(crate) fn timer(&self) -> Option { + match self.state { + HandshakeState::InitSent(HandshakeInitSentState { time_sent, .. }) => Some(time_sent), + _ => None, + } + } + + pub(crate) fn set_expired(&mut self) { + self.previous = HandshakeState::Expired; + self.state = HandshakeState::Expired; + } + + pub(crate) fn is_expired(&self) -> bool { + matches!(self.state, HandshakeState::Expired) + } + + pub(crate) fn has_cookie(&self) -> bool { + self.cookies.write_cookie.is_some() + } + + pub(crate) fn clear_cookie(&mut self) { + self.cookies.write_cookie = None; + } + + // The index used is 24 bits for peer index, allowing for 16M active peers per + // server and 8 bits for cyclic session index + fn inc_index(&mut self) -> u32 { + let index = self.next_index; + let idx8 = index as u8; + self.next_index = (index & !0xff) | u32::from(idx8.wrapping_add(1)); + self.next_index + } + + pub(crate) fn set_static_private( + &mut self, + private_key: x25519::StaticSecret, + public_key: x25519::PublicKey, + ) -> Result<(), WireGuardError> { + self.params.set_static_private(private_key, public_key) + } + + pub(super) fn receive_handshake_initialization<'a>( + &mut self, + packet: HandshakeInit, + dst: &'a mut [u8], + ) -> Result<(&'a mut [u8], Session), WireGuardError> { + // initiator.chaining_key = HASH(CONSTRUCTION) + let mut chaining_key = INITIAL_CHAIN_KEY; + // initiator.hash = HASH(HASH(initiator.chaining_key || IDENTIFIER) || + // responder.static_public) + let mut hash = INITIAL_CHAIN_HASH; + hash = b2s_hash(&hash, self.params.static_public.as_bytes()); + // msg.sender_index = little_endian(initiator.sender_index) + let peer_index = packet.sender_idx; + // msg.unencrypted_ephemeral = DH_PUBKEY(initiator.ephemeral_private) + let peer_ephemeral_public = x25519::PublicKey::from(*packet.unencrypted_ephemeral); + // initiator.hash = HASH(initiator.hash || msg.unencrypted_ephemeral) + hash = b2s_hash(&hash, peer_ephemeral_public.as_bytes()); + // temp = HMAC(initiator.chaining_key, msg.unencrypted_ephemeral) + // initiator.chaining_key = HMAC(temp, 0x1) + chaining_key = b2s_hmac( + &b2s_hmac(&chaining_key, peer_ephemeral_public.as_bytes()), + &[0x01], + ); + // temp = HMAC(initiator.chaining_key, DH(initiator.ephemeral_private, + // responder.static_public)) + let ephemeral_shared = self + .params + .static_private + .diffie_hellman(&peer_ephemeral_public); + let temp = b2s_hmac(&chaining_key, &ephemeral_shared.to_bytes()); + // initiator.chaining_key = HMAC(temp, 0x1) + chaining_key = b2s_hmac(&temp, &[0x01]); + // key = HMAC(temp, initiator.chaining_key || 0x2) + let key = b2s_hmac2(&temp, &chaining_key, &[0x02]); + + let mut peer_static_public_decrypted = [0u8; KEY_LEN]; + // msg.encrypted_static = AEAD(key, 0, initiator.static_public, initiator.hash) + aead_chacha20_open( + &mut peer_static_public_decrypted, + &key, + 0, + packet.encrypted_static, + &hash, + )?; + + ring::constant_time::verify_slices_are_equal( + self.params.peer_static_public.as_bytes(), + &peer_static_public_decrypted, + ) + .map_err(|_| WireGuardError::WrongKey)?; + + // initiator.hash = HASH(initiator.hash || msg.encrypted_static) + hash = b2s_hash(&hash, packet.encrypted_static); + // temp = HMAC(initiator.chaining_key, DH(initiator.static_private, + // responder.static_public)) + let temp = b2s_hmac(&chaining_key, self.params.static_shared.as_bytes()); + // initiator.chaining_key = HMAC(temp, 0x1) + chaining_key = b2s_hmac(&temp, &[0x01]); + // key = HMAC(temp, initiator.chaining_key || 0x2) + let key = b2s_hmac2(&temp, &chaining_key, &[0x02]); + // msg.encrypted_timestamp = AEAD(key, 0, TAI64N(), initiator.hash) + let mut timestamp = [0u8; TIMESTAMP_LEN]; + aead_chacha20_open(&mut timestamp, &key, 0, packet.encrypted_timestamp, &hash)?; + + let timestamp = Tai64N::parse(×tamp)?; + if !timestamp.after(&self.last_handshake_timestamp) { + // Possibly a replay + return Err(WireGuardError::WrongTai64nTimestamp) + } + self.last_handshake_timestamp = timestamp; + + // initiator.hash = HASH(initiator.hash || msg.encrypted_timestamp) + hash = b2s_hash(&hash, packet.encrypted_timestamp); + + self.previous = std::mem::replace(&mut self.state, HandshakeState::InitReceived { + chaining_key, + hash, + peer_ephemeral_public, + peer_index, + }); + + self.format_handshake_response(dst) + } + + pub(super) fn receive_handshake_response( + &mut self, + packet: HandshakeResponse, + ) -> Result { + // Check if there is a handshake awaiting a response and return the correct one + let (state, is_previous) = match (&self.state, &self.previous) { + (HandshakeState::InitSent(s), _) if s.local_index == packet.receiver_idx => (s, false), + (_, HandshakeState::InitSent(s)) if s.local_index == packet.receiver_idx => (s, true), + _ => return Err(WireGuardError::UnexpectedPacket), + }; + + let peer_index = packet.sender_idx; + let local_index = state.local_index; + + let unencrypted_ephemeral = x25519::PublicKey::from(*packet.unencrypted_ephemeral); + // msg.unencrypted_ephemeral = DH_PUBKEY(responder.ephemeral_private) + // responder.hash = HASH(responder.hash || msg.unencrypted_ephemeral) + let mut hash = b2s_hash(&state.hash, unencrypted_ephemeral.as_bytes()); + // temp = HMAC(responder.chaining_key, msg.unencrypted_ephemeral) + let temp = b2s_hmac(&state.chaining_key, unencrypted_ephemeral.as_bytes()); + // responder.chaining_key = HMAC(temp, 0x1) + let mut chaining_key = b2s_hmac(&temp, &[0x01]); + // temp = HMAC(responder.chaining_key, DH(responder.ephemeral_private, + // initiator.ephemeral_public)) + let ephemeral_shared = state + .ephemeral_private + .diffie_hellman(&unencrypted_ephemeral); + let temp = b2s_hmac(&chaining_key, &ephemeral_shared.to_bytes()); + // responder.chaining_key = HMAC(temp, 0x1) + chaining_key = b2s_hmac(&temp, &[0x01]); + // temp = HMAC(responder.chaining_key, DH(responder.ephemeral_private, + // initiator.static_public)) + let temp = b2s_hmac( + &chaining_key, + &self + .params + .static_private + .diffie_hellman(&unencrypted_ephemeral) + .to_bytes(), + ); + // responder.chaining_key = HMAC(temp, 0x1) + chaining_key = b2s_hmac(&temp, &[0x01]); + // temp = HMAC(responder.chaining_key, preshared_key) + let temp = b2s_hmac( + &chaining_key, + &self.params.preshared_key.unwrap_or([0u8; 32])[..], + ); + // responder.chaining_key = HMAC(temp, 0x1) + chaining_key = b2s_hmac(&temp, &[0x01]); + // temp2 = HMAC(temp, responder.chaining_key || 0x2) + let temp2 = b2s_hmac2(&temp, &chaining_key, &[0x02]); + // key = HMAC(temp, temp2 || 0x3) + let key = b2s_hmac2(&temp, &temp2, &[0x03]); + // responder.hash = HASH(responder.hash || temp2) + hash = b2s_hash(&hash, &temp2); + // msg.encrypted_nothing = AEAD(key, 0, [empty], responder.hash) + aead_chacha20_open(&mut [], &key, 0, packet.encrypted_nothing, &hash)?; + + // responder.hash = HASH(responder.hash || msg.encrypted_nothing) + // hash = b2s_hash(hash, buf[ENC_NOTHING_OFF..ENC_NOTHING_OFF + + // ENC_NOTHING_SZ]); + + // Derive keys + // temp1 = HMAC(initiator.chaining_key, [empty]) + // temp2 = HMAC(temp1, 0x1) + // temp3 = HMAC(temp1, temp2 || 0x2) + // initiator.sending_key = temp2 + // initiator.receiving_key = temp3 + // initiator.sending_key_counter = 0 + // initiator.receiving_key_counter = 0 + let temp1 = b2s_hmac(&chaining_key, &[]); + let temp2 = b2s_hmac(&temp1, &[0x01]); + let temp3 = b2s_hmac2(&temp1, &temp2, &[0x02]); + + let rtt_time = Instant::now().duration_since(state.time_sent); + self.last_rtt = Some(rtt_time.as_millis() as u32); + + if is_previous { + self.previous = HandshakeState::None; + } else { + self.state = HandshakeState::None; + } + Ok(Session::new(local_index, peer_index, temp3, temp2)) + } + + pub(super) fn receive_cookie_reply( + &mut self, + packet: PacketCookieReply, + ) -> Result<(), WireGuardError> { + let mac1 = match self.cookies.last_mac1 { + Some(mac) => mac, + None => return Err(WireGuardError::UnexpectedPacket), + }; + + let local_index = self.cookies.index; + if packet.receiver_idx != local_index { + return Err(WireGuardError::WrongIndex) + } + // msg.encrypted_cookie = XAEAD(HASH(LABEL_COOKIE || responder.static_public), + // msg.nonce, cookie, last_received_msg.mac1) + let key = b2s_hash(LABEL_COOKIE, self.params.peer_static_public.as_bytes()); // TODO: pre-compute + + let payload = Payload { + aad: &mac1[0..16], + msg: packet.encrypted_cookie, + }; + let plaintext = XChaCha20Poly1305::new_from_slice(&key) + .unwrap() + .decrypt(packet.nonce.into(), payload) + .map_err(|_| WireGuardError::InvalidAeadTag)?; + + let cookie = plaintext + .try_into() + .map_err(|_| WireGuardError::InvalidPacket)?; + self.cookies.write_cookie = Some(cookie); + Ok(()) + } + + // Compute and append mac1 and mac2 to a handshake message + fn append_mac1_and_mac2<'a>( + &mut self, + local_index: u32, + dst: &'a mut [u8], + ) -> Result<&'a mut [u8], WireGuardError> { + let mac1_off = dst.len() - 32; + let mac2_off = dst.len() - 16; + + // msg.mac1 = MAC(HASH(LABEL_MAC1 || responder.static_public), + // msg[0:offsetof(msg.mac1)]) + let msg_mac1 = b2s_keyed_mac_16(&self.params.sending_mac1_key, &dst[..mac1_off]); + + dst[mac1_off..mac2_off].copy_from_slice(&msg_mac1[..]); + + // msg.mac2 = MAC(initiator.last_received_cookie, msg[0:offsetof(msg.mac2)]) + let msg_mac2: [u8; 16] = if let Some(cookie) = self.cookies.write_cookie { + b2s_keyed_mac_16(&cookie, &dst[..mac2_off]) + } else { + [0u8; 16] + }; + + dst[mac2_off..].copy_from_slice(&msg_mac2[..]); + + self.cookies.index = local_index; + self.cookies.last_mac1 = Some(msg_mac1); + Ok(dst) + } + + pub(super) fn format_handshake_initiation<'a>( + &mut self, + dst: &'a mut [u8], + ) -> Result<&'a mut [u8], WireGuardError> { + if dst.len() < super::HANDSHAKE_INIT_SZ { + return Err(WireGuardError::DestinationBufferTooSmall) + } + + let (message_type, rest) = dst.split_at_mut(4); + let (sender_index, rest) = rest.split_at_mut(4); + let (unencrypted_ephemeral, rest) = rest.split_at_mut(32); + let (encrypted_static, rest) = rest.split_at_mut(32 + 16); + let (encrypted_timestamp, _) = rest.split_at_mut(12 + 16); + + let local_index = self.inc_index(); + + // initiator.chaining_key = HASH(CONSTRUCTION) + let mut chaining_key = INITIAL_CHAIN_KEY; + // initiator.hash = HASH(HASH(initiator.chaining_key || IDENTIFIER) || + // responder.static_public) + let mut hash = INITIAL_CHAIN_HASH; + hash = b2s_hash(&hash, self.params.peer_static_public.as_bytes()); + // initiator.ephemeral_private = DH_GENERATE() + let ephemeral_private = x25519::ReusableSecret::random_from_rng(OsRng); + // msg.message_type = 1 + // msg.reserved_zero = { 0, 0, 0 } + message_type.copy_from_slice(&super::HANDSHAKE_INIT.to_le_bytes()); + // msg.sender_index = little_endian(initiator.sender_index) + sender_index.copy_from_slice(&local_index.to_le_bytes()); + // msg.unencrypted_ephemeral = DH_PUBKEY(initiator.ephemeral_private) + unencrypted_ephemeral + .copy_from_slice(x25519::PublicKey::from(&ephemeral_private).as_bytes()); + // initiator.hash = HASH(initiator.hash || msg.unencrypted_ephemeral) + hash = b2s_hash(&hash, unencrypted_ephemeral); + // temp = HMAC(initiator.chaining_key, msg.unencrypted_ephemeral) + // initiator.chaining_key = HMAC(temp, 0x1) + chaining_key = b2s_hmac(&b2s_hmac(&chaining_key, unencrypted_ephemeral), &[0x01]); + // temp = HMAC(initiator.chaining_key, DH(initiator.ephemeral_private, + // responder.static_public)) + let ephemeral_shared = ephemeral_private.diffie_hellman(&self.params.peer_static_public); + let temp = b2s_hmac(&chaining_key, &ephemeral_shared.to_bytes()); + // initiator.chaining_key = HMAC(temp, 0x1) + chaining_key = b2s_hmac(&temp, &[0x01]); + // key = HMAC(temp, initiator.chaining_key || 0x2) + let key = b2s_hmac2(&temp, &chaining_key, &[0x02]); + // msg.encrypted_static = AEAD(key, 0, initiator.static_public, initiator.hash) + aead_chacha20_seal( + encrypted_static, + &key, + 0, + self.params.static_public.as_bytes(), + &hash, + ); + // initiator.hash = HASH(initiator.hash || msg.encrypted_static) + hash = b2s_hash(&hash, encrypted_static); + // temp = HMAC(initiator.chaining_key, DH(initiator.static_private, + // responder.static_public)) + let temp = b2s_hmac(&chaining_key, self.params.static_shared.as_bytes()); + // initiator.chaining_key = HMAC(temp, 0x1) + chaining_key = b2s_hmac(&temp, &[0x01]); + // key = HMAC(temp, initiator.chaining_key || 0x2) + let key = b2s_hmac2(&temp, &chaining_key, &[0x02]); + // msg.encrypted_timestamp = AEAD(key, 0, TAI64N(), initiator.hash) + let timestamp = self.stamper.stamp(); + aead_chacha20_seal(encrypted_timestamp, &key, 0, ×tamp, &hash); + // initiator.hash = HASH(initiator.hash || msg.encrypted_timestamp) + hash = b2s_hash(&hash, encrypted_timestamp); + + let time_now = Instant::now(); + self.previous = std::mem::replace( + &mut self.state, + HandshakeState::InitSent(HandshakeInitSentState { + local_index, + chaining_key, + hash, + ephemeral_private, + time_sent: time_now, + }), + ); + + self.append_mac1_and_mac2(local_index, &mut dst[..super::HANDSHAKE_INIT_SZ]) + } + + fn format_handshake_response<'a>( + &mut self, + dst: &'a mut [u8], + ) -> Result<(&'a mut [u8], Session), WireGuardError> { + if dst.len() < super::HANDSHAKE_RESP_SZ { + return Err(WireGuardError::DestinationBufferTooSmall) + } + + let state = std::mem::replace(&mut self.state, HandshakeState::None); + let (mut chaining_key, mut hash, peer_ephemeral_public, peer_index) = match state { + HandshakeState::InitReceived { + chaining_key, + hash, + peer_ephemeral_public, + peer_index, + } => (chaining_key, hash, peer_ephemeral_public, peer_index), + _ => { + panic!("Unexpected attempt to call send_handshake_response"); + } + }; + + let (message_type, rest) = dst.split_at_mut(4); + let (sender_index, rest) = rest.split_at_mut(4); + let (receiver_index, rest) = rest.split_at_mut(4); + let (unencrypted_ephemeral, rest) = rest.split_at_mut(32); + let (encrypted_nothing, _) = rest.split_at_mut(16); + + // responder.ephemeral_private = DH_GENERATE() + let ephemeral_private = x25519::ReusableSecret::random_from_rng(OsRng); + let local_index = self.inc_index(); + // msg.message_type = 2 + // msg.reserved_zero = { 0, 0, 0 } + message_type.copy_from_slice(&super::HANDSHAKE_RESP.to_le_bytes()); + // msg.sender_index = little_endian(responder.sender_index) + sender_index.copy_from_slice(&local_index.to_le_bytes()); + // msg.receiver_index = little_endian(initiator.sender_index) + receiver_index.copy_from_slice(&peer_index.to_le_bytes()); + // msg.unencrypted_ephemeral = DH_PUBKEY(initiator.ephemeral_private) + unencrypted_ephemeral + .copy_from_slice(x25519::PublicKey::from(&ephemeral_private).as_bytes()); + // responder.hash = HASH(responder.hash || msg.unencrypted_ephemeral) + hash = b2s_hash(&hash, unencrypted_ephemeral); + // temp = HMAC(responder.chaining_key, msg.unencrypted_ephemeral) + let temp = b2s_hmac(&chaining_key, unencrypted_ephemeral); + // responder.chaining_key = HMAC(temp, 0x1) + chaining_key = b2s_hmac(&temp, &[0x01]); + // temp = HMAC(responder.chaining_key, DH(responder.ephemeral_private, + // initiator.ephemeral_public)) + let ephemeral_shared = ephemeral_private.diffie_hellman(&peer_ephemeral_public); + let temp = b2s_hmac(&chaining_key, &ephemeral_shared.to_bytes()); + // responder.chaining_key = HMAC(temp, 0x1) + chaining_key = b2s_hmac(&temp, &[0x01]); + // temp = HMAC(responder.chaining_key, DH(responder.ephemeral_private, + // initiator.static_public)) + let temp = b2s_hmac( + &chaining_key, + &ephemeral_private + .diffie_hellman(&self.params.peer_static_public) + .to_bytes(), + ); + // responder.chaining_key = HMAC(temp, 0x1) + chaining_key = b2s_hmac(&temp, &[0x01]); + // temp = HMAC(responder.chaining_key, preshared_key) + let temp = b2s_hmac( + &chaining_key, + &self.params.preshared_key.unwrap_or([0u8; 32])[..], + ); + // responder.chaining_key = HMAC(temp, 0x1) + chaining_key = b2s_hmac(&temp, &[0x01]); + // temp2 = HMAC(temp, responder.chaining_key || 0x2) + let temp2 = b2s_hmac2(&temp, &chaining_key, &[0x02]); + // key = HMAC(temp, temp2 || 0x3) + let key = b2s_hmac2(&temp, &temp2, &[0x03]); + // responder.hash = HASH(responder.hash || temp2) + hash = b2s_hash(&hash, &temp2); + // msg.encrypted_nothing = AEAD(key, 0, [empty], responder.hash) + aead_chacha20_seal(encrypted_nothing, &key, 0, &[], &hash); + + // Derive keys + // temp1 = HMAC(initiator.chaining_key, [empty]) + // temp2 = HMAC(temp1, 0x1) + // temp3 = HMAC(temp1, temp2 || 0x2) + // initiator.sending_key = temp2 + // initiator.receiving_key = temp3 + // initiator.sending_key_counter = 0 + // initiator.receiving_key_counter = 0 + let temp1 = b2s_hmac(&chaining_key, &[]); + let temp2 = b2s_hmac(&temp1, &[0x01]); + let temp3 = b2s_hmac2(&temp1, &temp2, &[0x02]); + + let dst = self.append_mac1_and_mac2(local_index, &mut dst[..super::HANDSHAKE_RESP_SZ])?; + + Ok((dst, Session::new(local_index, peer_index, temp2, temp3))) + } +} diff --git a/burrow/src/wireguard/noise/mod.rs b/burrow/src/wireguard/noise/mod.rs new file mode 100755 index 0000000..3e8a6f0 --- /dev/null +++ b/burrow/src/wireguard/noise/mod.rs @@ -0,0 +1,609 @@ +// Copyright (c) 2019 Cloudflare, Inc. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +pub mod errors; +pub mod handshake; +pub mod rate_limiter; + +mod session; +mod timers; + +use std::{ + collections::VecDeque, + convert::{TryFrom, TryInto}, + net::{IpAddr, Ipv4Addr, Ipv6Addr}, + sync::Arc, + time::Duration, +}; + +use errors::WireGuardError; +use handshake::Handshake; +use rate_limiter::RateLimiter; +use timers::{TimerName, Timers}; + +/// The default value to use for rate limiting, when no other rate limiter is +/// defined +const PEER_HANDSHAKE_RATE_LIMIT: u64 = 10; + +const IPV4_MIN_HEADER_SIZE: usize = 20; +const IPV4_LEN_OFF: usize = 2; +const IPV4_SRC_IP_OFF: usize = 12; +const IPV4_DST_IP_OFF: usize = 16; +const IPV4_IP_SZ: usize = 4; + +const IPV6_MIN_HEADER_SIZE: usize = 40; +const IPV6_LEN_OFF: usize = 4; +const IPV6_SRC_IP_OFF: usize = 8; +const IPV6_DST_IP_OFF: usize = 24; +const IPV6_IP_SZ: usize = 16; + +const IP_LEN_SZ: usize = 2; + +const MAX_QUEUE_DEPTH: usize = 256; +/// number of sessions in the ring, better keep a PoT +const N_SESSIONS: usize = 8; + +pub mod x25519 { + pub use x25519_dalek::{ + EphemeralSecret, + PublicKey, + ReusableSecret, + SharedSecret, + StaticSecret, + }; +} + +#[derive(Debug)] +pub enum TunnResult<'a> { + Done, + Err(WireGuardError), + WriteToNetwork(&'a mut [u8]), + WriteToTunnelV4(&'a mut [u8], Ipv4Addr), + WriteToTunnelV6(&'a mut [u8], Ipv6Addr), +} + +impl<'a> From for TunnResult<'a> { + fn from(err: WireGuardError) -> TunnResult<'a> { + TunnResult::Err(err) + } +} + +/// Tunnel represents a point-to-point WireGuard connection +#[derive(Debug)] +pub struct Tunnel { + /// The handshake currently in progress + handshake: handshake::Handshake, + /// The N_SESSIONS most recent sessions, index is session id modulo + /// N_SESSIONS + sessions: [Option; N_SESSIONS], + /// Index of most recently used session + current: usize, + /// Queue to store blocked packets + packet_queue: VecDeque>, + /// Keeps tabs on the expiring timers + timers: timers::Timers, + tx_bytes: usize, + rx_bytes: usize, + rate_limiter: Arc, +} + +type MessageType = u32; +const HANDSHAKE_INIT: MessageType = 1; +const HANDSHAKE_RESP: MessageType = 2; +const COOKIE_REPLY: MessageType = 3; +const DATA: MessageType = 4; + +const HANDSHAKE_INIT_SZ: usize = 148; +const HANDSHAKE_RESP_SZ: usize = 92; +const COOKIE_REPLY_SZ: usize = 64; +const DATA_OVERHEAD_SZ: usize = 32; + +#[derive(Debug)] +pub struct HandshakeInit<'a> { + sender_idx: u32, + unencrypted_ephemeral: &'a [u8; 32], + encrypted_static: &'a [u8], + encrypted_timestamp: &'a [u8], +} + +#[derive(Debug)] +pub struct HandshakeResponse<'a> { + sender_idx: u32, + pub receiver_idx: u32, + unencrypted_ephemeral: &'a [u8; 32], + encrypted_nothing: &'a [u8], +} + +#[derive(Debug)] +pub struct PacketCookieReply<'a> { + pub receiver_idx: u32, + nonce: &'a [u8], + encrypted_cookie: &'a [u8], +} + +#[derive(Debug)] +pub struct PacketData<'a> { + pub receiver_idx: u32, + counter: u64, + encrypted_encapsulated_packet: &'a [u8], +} + +/// Describes a packet from network +#[derive(Debug)] +pub enum Packet<'a> { + HandshakeInit(HandshakeInit<'a>), + HandshakeResponse(HandshakeResponse<'a>), + PacketCookieReply(PacketCookieReply<'a>), + PacketData(PacketData<'a>), +} + +impl Tunnel { + #[inline(always)] + pub fn parse_incoming_packet(src: &[u8]) -> Result { + if src.len() < 4 { + return Err(WireGuardError::InvalidPacket) + } + + // Checks the type, as well as the reserved zero fields + let packet_type = u32::from_le_bytes(src[0..4].try_into().unwrap()); + + Ok(match (packet_type, src.len()) { + (HANDSHAKE_INIT, HANDSHAKE_INIT_SZ) => Packet::HandshakeInit(HandshakeInit { + sender_idx: u32::from_le_bytes(src[4..8].try_into().unwrap()), + unencrypted_ephemeral: <&[u8; 32] as TryFrom<&[u8]>>::try_from(&src[8..40]) + .expect("length already checked above"), + encrypted_static: &src[40..88], + encrypted_timestamp: &src[88..116], + }), + (HANDSHAKE_RESP, HANDSHAKE_RESP_SZ) => Packet::HandshakeResponse(HandshakeResponse { + sender_idx: u32::from_le_bytes(src[4..8].try_into().unwrap()), + receiver_idx: u32::from_le_bytes(src[8..12].try_into().unwrap()), + unencrypted_ephemeral: <&[u8; 32] as TryFrom<&[u8]>>::try_from(&src[12..44]) + .expect("length already checked above"), + encrypted_nothing: &src[44..60], + }), + (COOKIE_REPLY, COOKIE_REPLY_SZ) => Packet::PacketCookieReply(PacketCookieReply { + receiver_idx: u32::from_le_bytes(src[4..8].try_into().unwrap()), + nonce: &src[8..32], + encrypted_cookie: &src[32..64], + }), + (DATA, DATA_OVERHEAD_SZ..=std::usize::MAX) => Packet::PacketData(PacketData { + receiver_idx: u32::from_le_bytes(src[4..8].try_into().unwrap()), + counter: u64::from_le_bytes(src[8..16].try_into().unwrap()), + encrypted_encapsulated_packet: &src[16..], + }), + _ => return Err(WireGuardError::InvalidPacket), + }) + } + + pub fn is_expired(&self) -> bool { + self.handshake.is_expired() + } + + pub fn dst_address(packet: &[u8]) -> Option { + if packet.is_empty() { + return None + } + + match packet[0] >> 4 { + 4 if packet.len() >= IPV4_MIN_HEADER_SIZE => { + let addr_bytes: [u8; IPV4_IP_SZ] = packet + [IPV4_DST_IP_OFF..IPV4_DST_IP_OFF + IPV4_IP_SZ] + .try_into() + .unwrap(); + Some(IpAddr::from(addr_bytes)) + } + 6 if packet.len() >= IPV6_MIN_HEADER_SIZE => { + let addr_bytes: [u8; IPV6_IP_SZ] = packet + [IPV6_DST_IP_OFF..IPV6_DST_IP_OFF + IPV6_IP_SZ] + .try_into() + .unwrap(); + Some(IpAddr::from(addr_bytes)) + } + _ => None, + } + } + + /// Create a new tunnel using own private key and the peer public key + pub fn new( + static_private: x25519::StaticSecret, + peer_static_public: x25519::PublicKey, + preshared_key: Option<[u8; 32]>, + persistent_keepalive: Option, + index: u32, + rate_limiter: Option>, + ) -> Result { + let static_public = x25519::PublicKey::from(&static_private); + + let tunn = Tunnel { + handshake: Handshake::new( + static_private, + static_public, + peer_static_public, + index << 8, + preshared_key, + ) + .map_err(|_| "Invalid parameters")?, + sessions: Default::default(), + current: Default::default(), + tx_bytes: Default::default(), + rx_bytes: Default::default(), + + packet_queue: VecDeque::new(), + timers: Timers::new(persistent_keepalive, rate_limiter.is_none()), + + rate_limiter: rate_limiter.unwrap_or_else(|| { + Arc::new(RateLimiter::new(&static_public, PEER_HANDSHAKE_RATE_LIMIT)) + }), + }; + + Ok(tunn) + } + + /// Update the private key and clear existing sessions + pub fn set_static_private( + &mut self, + static_private: x25519::StaticSecret, + static_public: x25519::PublicKey, + rate_limiter: Option>, + ) -> Result<(), WireGuardError> { + self.timers.should_reset_rr = rate_limiter.is_none(); + self.rate_limiter = rate_limiter.unwrap_or_else(|| { + Arc::new(RateLimiter::new(&static_public, PEER_HANDSHAKE_RATE_LIMIT)) + }); + self.handshake + .set_static_private(static_private, static_public)?; + for s in &mut self.sessions { + *s = None; + } + Ok(()) + } + + /// Encapsulate a single packet from the tunnel interface. + /// Returns TunnResult. + /// + /// # Panics + /// Panics if dst buffer is too small. + /// Size of dst should be at least src.len() + 32, and no less than 148 + /// bytes. + pub fn encapsulate<'a>(&mut self, src: &[u8], dst: &'a mut [u8]) -> TunnResult<'a> { + let current = self.current; + if let Some(ref session) = self.sessions[current % N_SESSIONS] { + // Send the packet using an established session + let packet = session.format_packet_data(src, dst); + self.timer_tick(TimerName::TimeLastPacketSent); + // Exclude Keepalive packets from timer update. + if !src.is_empty() { + self.timer_tick(TimerName::TimeLastDataPacketSent); + } + self.tx_bytes += src.len(); + return TunnResult::WriteToNetwork(packet) + } + + // If there is no session, queue the packet for future retry + self.queue_packet(src); + // Initiate a new handshake if none is in progress + self.format_handshake_initiation(dst, false) + } + + /// Receives a UDP datagram from the network and parses it. + /// Returns TunnResult. + /// + /// If the result is of type TunnResult::WriteToNetwork, should repeat the + /// call with empty datagram, until TunnResult::Done is returned. If + /// batch processing packets, it is OK to defer until last + /// packet is processed. + pub fn decapsulate<'a>( + &mut self, + src_addr: Option, + datagram: &[u8], + dst: &'a mut [u8], + ) -> TunnResult<'a> { + if datagram.is_empty() { + // Indicates a repeated call + return self.send_queued_packet(dst) + } + + let mut cookie = [0u8; COOKIE_REPLY_SZ]; + let packet = match self + .rate_limiter + .verify_packet(src_addr, datagram, &mut cookie) + { + Ok(packet) => packet, + Err(TunnResult::WriteToNetwork(cookie)) => { + dst[..cookie.len()].copy_from_slice(cookie); + return TunnResult::WriteToNetwork(&mut dst[..cookie.len()]) + } + Err(TunnResult::Err(e)) => return TunnResult::Err(e), + _ => unreachable!(), + }; + + self.handle_verified_packet(packet, dst) + } + + pub(crate) fn handle_verified_packet<'a>( + &mut self, + packet: Packet, + dst: &'a mut [u8], + ) -> TunnResult<'a> { + match packet { + Packet::HandshakeInit(p) => self.handle_handshake_init(p, dst), + Packet::HandshakeResponse(p) => self.handle_handshake_response(p, dst), + Packet::PacketCookieReply(p) => self.handle_cookie_reply(p), + Packet::PacketData(p) => self.handle_data(p, dst), + } + .unwrap_or_else(TunnResult::from) + } + + fn handle_handshake_init<'a>( + &mut self, + p: HandshakeInit, + dst: &'a mut [u8], + ) -> Result, WireGuardError> { + tracing::debug!( + message = "Received handshake_initiation", + remote_idx = p.sender_idx + ); + + let (packet, session) = self.handshake.receive_handshake_initialization(p, dst)?; + + // Store new session in ring buffer + let index = session.local_index(); + self.sessions[index % N_SESSIONS] = Some(session); + + self.timer_tick(TimerName::TimeLastPacketReceived); + self.timer_tick(TimerName::TimeLastPacketSent); + self.timer_tick_session_established(false, index); // New session established, we are not the initiator + + tracing::debug!(message = "Sending handshake_response", local_idx = index); + + Ok(TunnResult::WriteToNetwork(packet)) + } + + fn handle_handshake_response<'a>( + &mut self, + p: HandshakeResponse, + dst: &'a mut [u8], + ) -> Result, WireGuardError> { + tracing::debug!( + message = "Received handshake_response", + local_idx = p.receiver_idx, + remote_idx = p.sender_idx + ); + + let session = self.handshake.receive_handshake_response(p)?; + + let keepalive_packet = session.format_packet_data(&[], dst); + // Store new session in ring buffer + let l_idx = session.local_index(); + let index = l_idx % N_SESSIONS; + self.sessions[index] = Some(session); + + self.timer_tick(TimerName::TimeLastPacketReceived); + self.timer_tick_session_established(true, index); // New session established, we are the initiator + self.set_current_session(l_idx); + + tracing::debug!("Sending keepalive"); + + Ok(TunnResult::WriteToNetwork(keepalive_packet)) // Send a keepalive as + // a response + } + + fn handle_cookie_reply<'a>( + &mut self, + p: PacketCookieReply, + ) -> Result, WireGuardError> { + tracing::debug!( + message = "Received cookie_reply", + local_idx = p.receiver_idx + ); + + self.handshake.receive_cookie_reply(p)?; + self.timer_tick(TimerName::TimeLastPacketReceived); + self.timer_tick(TimerName::TimeCookieReceived); + + tracing::debug!("Did set cookie"); + + Ok(TunnResult::Done) + } + + /// Update the index of the currently used session, if needed + fn set_current_session(&mut self, new_idx: usize) { + let cur_idx = self.current; + if cur_idx == new_idx { + // There is nothing to do, already using this session, this is the common case + return + } + if self.sessions[cur_idx % N_SESSIONS].is_none() + || self.timers.session_timers[new_idx % N_SESSIONS] + >= self.timers.session_timers[cur_idx % N_SESSIONS] + { + self.current = new_idx; + tracing::debug!(message = "New session", session = new_idx); + } + } + + /// Decrypts a data packet, and stores the decapsulated packet in dst. + fn handle_data<'a>( + &mut self, + packet: PacketData, + dst: &'a mut [u8], + ) -> Result, WireGuardError> { + let r_idx = packet.receiver_idx as usize; + let idx = r_idx % N_SESSIONS; + + // Get the (probably) right session + let decapsulated_packet = { + let session = self.sessions[idx].as_ref(); + let session = session.ok_or_else(|| { + tracing::trace!(message = "No current session available", remote_idx = r_idx); + WireGuardError::NoCurrentSession + })?; + session.receive_packet_data(packet, dst)? + }; + + self.set_current_session(r_idx); + + self.timer_tick(TimerName::TimeLastPacketReceived); + + Ok(self.validate_decapsulated_packet(decapsulated_packet)) + } + + /// Formats a new handshake initiation message and store it in dst. If + /// force_resend is true will send a new handshake, even if a handshake + /// is already in progress (for example when a handshake times out) + pub fn format_handshake_initiation<'a>( + &mut self, + dst: &'a mut [u8], + force_resend: bool, + ) -> TunnResult<'a> { + if self.handshake.is_in_progress() && !force_resend { + return TunnResult::Done + } + + if self.handshake.is_expired() { + self.timers.clear(); + } + + let starting_new_handshake = !self.handshake.is_in_progress(); + + match self.handshake.format_handshake_initiation(dst) { + Ok(packet) => { + tracing::debug!("Sending handshake_initiation"); + + if starting_new_handshake { + self.timer_tick(TimerName::TimeLastHandshakeStarted); + } + self.timer_tick(TimerName::TimeLastPacketSent); + TunnResult::WriteToNetwork(packet) + } + Err(e) => TunnResult::Err(e), + } + } + + /// Check if an IP packet is v4 or v6, truncate to the length indicated by + /// the length field Returns the truncated packet and the source IP as + /// TunnResult + fn validate_decapsulated_packet<'a>(&mut self, packet: &'a mut [u8]) -> TunnResult<'a> { + let (computed_len, src_ip_address) = match packet.len() { + 0 => return TunnResult::Done, // This is keepalive, and not an error + _ if packet[0] >> 4 == 4 && packet.len() >= IPV4_MIN_HEADER_SIZE => { + let len_bytes: [u8; IP_LEN_SZ] = packet[IPV4_LEN_OFF..IPV4_LEN_OFF + IP_LEN_SZ] + .try_into() + .unwrap(); + let addr_bytes: [u8; IPV4_IP_SZ] = packet + [IPV4_SRC_IP_OFF..IPV4_SRC_IP_OFF + IPV4_IP_SZ] + .try_into() + .unwrap(); + ( + u16::from_be_bytes(len_bytes) as usize, + IpAddr::from(addr_bytes), + ) + } + _ if packet[0] >> 4 == 6 && packet.len() >= IPV6_MIN_HEADER_SIZE => { + let len_bytes: [u8; IP_LEN_SZ] = packet[IPV6_LEN_OFF..IPV6_LEN_OFF + IP_LEN_SZ] + .try_into() + .unwrap(); + let addr_bytes: [u8; IPV6_IP_SZ] = packet + [IPV6_SRC_IP_OFF..IPV6_SRC_IP_OFF + IPV6_IP_SZ] + .try_into() + .unwrap(); + ( + u16::from_be_bytes(len_bytes) as usize + IPV6_MIN_HEADER_SIZE, + IpAddr::from(addr_bytes), + ) + } + _ => return TunnResult::Err(WireGuardError::InvalidPacket), + }; + + if computed_len > packet.len() { + return TunnResult::Err(WireGuardError::InvalidPacket) + } + + self.timer_tick(TimerName::TimeLastDataPacketReceived); + self.rx_bytes += computed_len; + + match src_ip_address { + IpAddr::V4(addr) => TunnResult::WriteToTunnelV4(&mut packet[..computed_len], addr), + IpAddr::V6(addr) => TunnResult::WriteToTunnelV6(&mut packet[..computed_len], addr), + } + } + + /// Get a packet from the queue, and try to encapsulate it + fn send_queued_packet<'a>(&mut self, dst: &'a mut [u8]) -> TunnResult<'a> { + if let Some(packet) = self.dequeue_packet() { + match self.encapsulate(&packet, dst) { + TunnResult::Err(_) => { + // On error, return packet to the queue + self.requeue_packet(packet); + } + r => return r, + } + } + TunnResult::Done + } + + /// Push packet to the back of the queue + fn queue_packet(&mut self, packet: &[u8]) { + if self.packet_queue.len() < MAX_QUEUE_DEPTH { + // Drop if too many are already in queue + self.packet_queue.push_back(packet.to_vec()); + } + } + + /// Push packet to the front of the queue + fn requeue_packet(&mut self, packet: Vec) { + if self.packet_queue.len() < MAX_QUEUE_DEPTH { + // Drop if too many are already in queue + self.packet_queue.push_front(packet); + } + } + + fn dequeue_packet(&mut self) -> Option> { + self.packet_queue.pop_front() + } + + fn estimate_loss(&self) -> f32 { + let session_idx = self.current; + + let mut weight = 9.0; + let mut cur_avg = 0.0; + let mut total_weight = 0.0; + + for i in 0..N_SESSIONS { + if let Some(ref session) = self.sessions[(session_idx.wrapping_sub(i)) % N_SESSIONS] { + let (expected, received) = session.current_packet_cnt(); + + let loss = if expected == 0 { + 0.0 + } else { + 1.0 - received as f32 / expected as f32 + }; + + cur_avg += loss * weight; + total_weight += weight; + weight /= 3.0; + } + } + + if total_weight == 0.0 { + 0.0 + } else { + cur_avg / total_weight + } + } + + /// Return stats from the tunnel: + /// * Time since last handshake in seconds + /// * Data bytes sent + /// * Data bytes received + pub fn stats(&self) -> (Option, usize, usize, f32, Option) { + let time = self.time_since_last_handshake(); + let tx_bytes = self.tx_bytes; + let rx_bytes = self.rx_bytes; + let loss = self.estimate_loss(); + let rtt = self.handshake.last_rtt; + + (time, tx_bytes, rx_bytes, loss, rtt) + } +} diff --git a/burrow/src/wireguard/noise/rate_limiter.rs b/burrow/src/wireguard/noise/rate_limiter.rs new file mode 100755 index 0000000..8266fe4 --- /dev/null +++ b/burrow/src/wireguard/noise/rate_limiter.rs @@ -0,0 +1,209 @@ +use std::{ + net::IpAddr, + sync::atomic::{AtomicU64, Ordering}, + time::Instant, +}; + +use aead::{generic_array::GenericArray, AeadInPlace, KeyInit}; +use chacha20poly1305::{Key, XChaCha20Poly1305}; +use parking_lot::Mutex; +use rand_core::{OsRng, RngCore}; +use ring::constant_time::verify_slices_are_equal; + +use super::{ + handshake::{ + b2s_hash, + b2s_keyed_mac_16, + b2s_keyed_mac_16_2, + b2s_mac_24, + LABEL_COOKIE, + LABEL_MAC1, + }, + HandshakeInit, + HandshakeResponse, + Packet, + TunnResult, + Tunnel, + WireGuardError, +}; + +const COOKIE_REFRESH: u64 = 128; // Use 128 and not 120 so the compiler can optimize out the division +const COOKIE_SIZE: usize = 16; +const COOKIE_NONCE_SIZE: usize = 24; + +/// How often should reset count in seconds +const RESET_PERIOD: u64 = 1; + +type Cookie = [u8; COOKIE_SIZE]; + +/// There are two places where WireGuard requires "randomness" for cookies +/// * The 24 byte nonce in the cookie massage - here the only goal is to avoid +/// nonce reuse +/// * A secret value that changes every two minutes +/// Because the main goal of the cookie is simply for a party to prove ownership +/// of an IP address we can relax the randomness definition a bit, in order to +/// avoid locking, because using less resources is the main goal of any DoS +/// prevention mechanism. In order to avoid locking and calls to rand we derive +/// pseudo random values using the AEAD and some counters. +#[derive(Debug)] +pub struct RateLimiter { + /// The key we use to derive the nonce + nonce_key: [u8; 32], + /// The key we use to derive the cookie + secret_key: [u8; 16], + start_time: Instant, + /// A single 64 bit counter (should suffice for many years) + nonce_ctr: AtomicU64, + mac1_key: [u8; 32], + cookie_key: Key, + limit: u64, + /// The counter since last reset + count: AtomicU64, + /// The time last reset was performed on this rate limiter + last_reset: Mutex, +} + +impl RateLimiter { + pub fn new(public_key: &super::x25519::PublicKey, limit: u64) -> Self { + let mut secret_key = [0u8; 16]; + OsRng.fill_bytes(&mut secret_key); + RateLimiter { + nonce_key: Self::rand_bytes(), + secret_key, + start_time: Instant::now(), + nonce_ctr: AtomicU64::new(0), + mac1_key: b2s_hash(LABEL_MAC1, public_key.as_bytes()), + cookie_key: b2s_hash(LABEL_COOKIE, public_key.as_bytes()).into(), + limit, + count: AtomicU64::new(0), + last_reset: Mutex::new(Instant::now()), + } + } + + fn rand_bytes() -> [u8; 32] { + let mut key = [0u8; 32]; + OsRng.fill_bytes(&mut key); + key + } + + /// Reset packet count (ideally should be called with a period of 1 second) + pub fn reset_count(&self) { + // The rate limiter is not very accurate, but at the scale we care about it + // doesn't matter much + let current_time = Instant::now(); + let mut last_reset_time = self.last_reset.lock(); + if current_time.duration_since(*last_reset_time).as_secs() >= RESET_PERIOD { + self.count.store(0, Ordering::SeqCst); + *last_reset_time = current_time; + } + } + + /// Compute the correct cookie value based on the current secret value and + /// the source IP + fn current_cookie(&self, addr: IpAddr) -> Cookie { + let mut addr_bytes = [0u8; 16]; + + match addr { + IpAddr::V4(a) => addr_bytes[..4].copy_from_slice(&a.octets()[..]), + IpAddr::V6(a) => addr_bytes[..].copy_from_slice(&a.octets()[..]), + } + + // The current cookie for a given IP is the + // MAC(responder.changing_secret_every_two_minutes, initiator.ip_address) + // First we derive the secret from the current time, the value of cur_counter + // would change with time. + let cur_counter = Instant::now().duration_since(self.start_time).as_secs() / COOKIE_REFRESH; + + // Next we derive the cookie + b2s_keyed_mac_16_2(&self.secret_key, &cur_counter.to_le_bytes(), &addr_bytes) + } + + fn nonce(&self) -> [u8; COOKIE_NONCE_SIZE] { + let ctr = self.nonce_ctr.fetch_add(1, Ordering::Relaxed); + + b2s_mac_24(&self.nonce_key, &ctr.to_le_bytes()) + } + + fn is_under_load(&self) -> bool { + self.count.fetch_add(1, Ordering::SeqCst) >= self.limit + } + + pub(crate) fn format_cookie_reply<'a>( + &self, + idx: u32, + cookie: Cookie, + mac1: &[u8], + dst: &'a mut [u8], + ) -> Result<&'a mut [u8], WireGuardError> { + if dst.len() < super::COOKIE_REPLY_SZ { + return Err(WireGuardError::DestinationBufferTooSmall) + } + + let (message_type, rest) = dst.split_at_mut(4); + let (receiver_index, rest) = rest.split_at_mut(4); + let (nonce, rest) = rest.split_at_mut(24); + let (encrypted_cookie, _) = rest.split_at_mut(16 + 16); + + // msg.message_type = 3 + // msg.reserved_zero = { 0, 0, 0 } + message_type.copy_from_slice(&super::COOKIE_REPLY.to_le_bytes()); + // msg.receiver_index = little_endian(initiator.sender_index) + receiver_index.copy_from_slice(&idx.to_le_bytes()); + nonce.copy_from_slice(&self.nonce()[..]); + + let cipher = XChaCha20Poly1305::new(&self.cookie_key); + + let iv = GenericArray::from_slice(nonce); + + encrypted_cookie[..16].copy_from_slice(&cookie); + let tag = cipher + .encrypt_in_place_detached(iv, mac1, &mut encrypted_cookie[..16]) + .map_err(|_| WireGuardError::DestinationBufferTooSmall)?; + + encrypted_cookie[16..].copy_from_slice(&tag); + + Ok(&mut dst[..super::COOKIE_REPLY_SZ]) + } + + /// Verify the MAC fields on the datagram, and apply rate limiting if needed + pub fn verify_packet<'a, 'b>( + &self, + src_addr: Option, + src: &'a [u8], + dst: &'b mut [u8], + ) -> Result, TunnResult<'b>> { + let packet = Tunnel::parse_incoming_packet(src)?; + + // Verify and rate limit handshake messages only + if let Packet::HandshakeInit(HandshakeInit { sender_idx, .. }) + | Packet::HandshakeResponse(HandshakeResponse { sender_idx, .. }) = packet + { + let (msg, macs) = src.split_at(src.len() - 32); + let (mac1, mac2) = macs.split_at(16); + + let computed_mac1 = b2s_keyed_mac_16(&self.mac1_key, msg); + verify_slices_are_equal(&computed_mac1[..16], mac1) + .map_err(|_| TunnResult::Err(WireGuardError::InvalidMac))?; + + if self.is_under_load() { + let addr = match src_addr { + None => return Err(TunnResult::Err(WireGuardError::UnderLoad)), + Some(addr) => addr, + }; + + // Only given an address can we validate mac2 + let cookie = self.current_cookie(addr); + let computed_mac2 = b2s_keyed_mac_16_2(&cookie, msg, mac1); + + if verify_slices_are_equal(&computed_mac2[..16], mac2).is_err() { + let cookie_packet = self + .format_cookie_reply(sender_idx, cookie, mac1, dst) + .map_err(TunnResult::Err)?; + return Err(TunnResult::WriteToNetwork(cookie_packet)) + } + } + } + + Ok(packet) + } +} diff --git a/burrow/src/wireguard/noise/session.rs b/burrow/src/wireguard/noise/session.rs new file mode 100755 index 0000000..f899b86 --- /dev/null +++ b/burrow/src/wireguard/noise/session.rs @@ -0,0 +1,279 @@ +// Copyright (c) 2019 Cloudflare, Inc. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +use std::sync::atomic::{AtomicUsize, Ordering}; + +use parking_lot::Mutex; +use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305}; + +use super::{errors::WireGuardError, PacketData}; + +pub struct Session { + pub(crate) receiving_index: u32, + sending_index: u32, + receiver: LessSafeKey, + sender: LessSafeKey, + sending_key_counter: AtomicUsize, + receiving_key_counter: Mutex, +} + +impl std::fmt::Debug for Session { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!( + f, + "Session: {}<- ->{}", + self.receiving_index, self.sending_index + ) + } +} + +/// Where encrypted data resides in a data packet +const DATA_OFFSET: usize = 16; +/// The overhead of the AEAD +const AEAD_SIZE: usize = 16; + +// Receiving buffer constants +const WORD_SIZE: u64 = 64; +const N_WORDS: u64 = 16; // Suffice to reorder 64*16 = 1024 packets; can be increased at will +const N_BITS: u64 = WORD_SIZE * N_WORDS; + +#[derive(Debug, Clone, Default)] +struct ReceivingKeyCounterValidator { + /// In order to avoid replays while allowing for some reordering of the + /// packets, we keep a bitmap of received packets, and the value of the + /// highest counter + next: u64, + /// Used to estimate packet loss + receive_cnt: u64, + bitmap: [u64; N_WORDS as usize], +} + +impl ReceivingKeyCounterValidator { + #[inline(always)] + fn set_bit(&mut self, idx: u64) { + let bit_idx = idx % N_BITS; + let word = (bit_idx / WORD_SIZE) as usize; + let bit = (bit_idx % WORD_SIZE) as usize; + self.bitmap[word] |= 1 << bit; + } + + #[inline(always)] + fn clear_bit(&mut self, idx: u64) { + let bit_idx = idx % N_BITS; + let word = (bit_idx / WORD_SIZE) as usize; + let bit = (bit_idx % WORD_SIZE) as usize; + self.bitmap[word] &= !(1u64 << bit); + } + + /// Clear the word that contains idx + #[inline(always)] + fn clear_word(&mut self, idx: u64) { + let bit_idx = idx % N_BITS; + let word = (bit_idx / WORD_SIZE) as usize; + self.bitmap[word] = 0; + } + + /// Returns true if bit is set, false otherwise + #[inline(always)] + fn check_bit(&self, idx: u64) -> bool { + let bit_idx = idx % N_BITS; + let word = (bit_idx / WORD_SIZE) as usize; + let bit = (bit_idx % WORD_SIZE) as usize; + ((self.bitmap[word] >> bit) & 1) == 1 + } + + /// Returns true if the counter was not yet received, and is not too far + /// back + #[inline(always)] + fn will_accept(&self, counter: u64) -> Result<(), WireGuardError> { + if counter >= self.next { + // As long as the counter is growing no replay took place for sure + return Ok(()) + } + if counter + N_BITS < self.next { + // Drop if too far back + return Err(WireGuardError::InvalidCounter) + } + if !self.check_bit(counter) { + Ok(()) + } else { + Err(WireGuardError::DuplicateCounter) + } + } + + /// Marks the counter as received, and returns true if it is still good (in + /// case during decryption something changed) + #[inline(always)] + fn mark_did_receive(&mut self, counter: u64) -> Result<(), WireGuardError> { + if counter + N_BITS < self.next { + // Drop if too far back + return Err(WireGuardError::InvalidCounter) + } + if counter == self.next { + // Usually the packets arrive in order, in that case we simply mark the bit and + // increment the counter + self.set_bit(counter); + self.next += 1; + return Ok(()) + } + if counter < self.next { + // A packet arrived out of order, check if it is valid, and mark + if self.check_bit(counter) { + return Err(WireGuardError::InvalidCounter) + } + self.set_bit(counter); + return Ok(()) + } + // Packets where dropped, or maybe reordered, skip them and mark unused + if counter - self.next >= N_BITS { + // Too far ahead, clear all the bits + for c in self.bitmap.iter_mut() { + *c = 0; + } + } else { + let mut i = self.next; + while i % WORD_SIZE != 0 && i < counter { + // Clear until i aligned to word size + self.clear_bit(i); + i += 1; + } + while i + WORD_SIZE < counter { + // Clear whole word at a time + self.clear_word(i); + i = (i + WORD_SIZE) & 0u64.wrapping_sub(WORD_SIZE); + } + while i < counter { + // Clear any remaining bits + self.clear_bit(i); + i += 1; + } + } + self.set_bit(counter); + self.next = counter + 1; + Ok(()) + } +} + +impl Session { + pub(super) fn new( + local_index: u32, + peer_index: u32, + receiving_key: [u8; 32], + sending_key: [u8; 32], + ) -> Session { + Session { + receiving_index: local_index, + sending_index: peer_index, + receiver: LessSafeKey::new( + UnboundKey::new(&CHACHA20_POLY1305, &receiving_key).unwrap(), + ), + sender: LessSafeKey::new(UnboundKey::new(&CHACHA20_POLY1305, &sending_key).unwrap()), + sending_key_counter: AtomicUsize::new(0), + receiving_key_counter: Mutex::new(Default::default()), + } + } + + pub(super) fn local_index(&self) -> usize { + self.receiving_index as usize + } + + /// Returns true if receiving counter is good to use + fn receiving_counter_quick_check(&self, counter: u64) -> Result<(), WireGuardError> { + let counter_validator = self.receiving_key_counter.lock(); + counter_validator.will_accept(counter) + } + + /// Returns true if receiving counter is good to use, and marks it as used { + fn receiving_counter_mark(&self, counter: u64) -> Result<(), WireGuardError> { + let mut counter_validator = self.receiving_key_counter.lock(); + let ret = counter_validator.mark_did_receive(counter); + if ret.is_ok() { + counter_validator.receive_cnt += 1; + } + ret + } + + /// src - an IP packet from the interface + /// dst - pre-allocated space to hold the encapsulating UDP packet to send + /// over the network returns the size of the formatted packet + pub(super) fn format_packet_data<'a>(&self, src: &[u8], dst: &'a mut [u8]) -> &'a mut [u8] { + if dst.len() < src.len() + super::DATA_OVERHEAD_SZ { + panic!("The destination buffer is too small"); + } + + let sending_key_counter = self.sending_key_counter.fetch_add(1, Ordering::Relaxed) as u64; + + let (message_type, rest) = dst.split_at_mut(4); + let (receiver_index, rest) = rest.split_at_mut(4); + let (counter, data) = rest.split_at_mut(8); + + message_type.copy_from_slice(&super::DATA.to_le_bytes()); + receiver_index.copy_from_slice(&self.sending_index.to_le_bytes()); + counter.copy_from_slice(&sending_key_counter.to_le_bytes()); + + // TODO: spec requires padding to 16 bytes, but actually works fine without it + let n = { + let mut nonce = [0u8; 12]; + nonce[4..12].copy_from_slice(&sending_key_counter.to_le_bytes()); + data[..src.len()].copy_from_slice(src); + self.sender + .seal_in_place_separate_tag( + Nonce::assume_unique_for_key(nonce), + Aad::from(&[]), + &mut data[..src.len()], + ) + .map(|tag| { + data[src.len()..src.len() + AEAD_SIZE].copy_from_slice(tag.as_ref()); + src.len() + AEAD_SIZE + }) + .unwrap() + }; + + &mut dst[..DATA_OFFSET + n] + } + + /// packet - a data packet we received from the network + /// dst - pre-allocated space to hold the encapsulated IP packet, to send to + /// the interface dst will always take less space than src + /// return the size of the encapsulated packet on success + pub(super) fn receive_packet_data<'a>( + &self, + packet: PacketData, + dst: &'a mut [u8], + ) -> Result<&'a mut [u8], WireGuardError> { + let ct_len = packet.encrypted_encapsulated_packet.len(); + if dst.len() < ct_len { + // This is a very incorrect use of the library, therefore panic and not error + panic!("The destination buffer is too small"); + } + if packet.receiver_idx != self.receiving_index { + return Err(WireGuardError::WrongIndex) + } + // Don't reuse counters, in case this is a replay attack we want to quickly + // check the counter without running expensive decryption + self.receiving_counter_quick_check(packet.counter)?; + + let ret = { + let mut nonce = [0u8; 12]; + nonce[4..12].copy_from_slice(&packet.counter.to_le_bytes()); + dst[..ct_len].copy_from_slice(packet.encrypted_encapsulated_packet); + self.receiver + .open_in_place( + Nonce::assume_unique_for_key(nonce), + Aad::from(&[]), + &mut dst[..ct_len], + ) + .map_err(|_| WireGuardError::InvalidAeadTag)? + }; + + // After decryption is done, check counter again, and mark as received + self.receiving_counter_mark(packet.counter)?; + Ok(ret) + } + + /// Returns the estimated downstream packet loss for this session + pub(super) fn current_packet_cnt(&self) -> (u64, u64) { + let counter_validator = self.receiving_key_counter.lock(); + (counter_validator.next, counter_validator.receive_cnt) + } +} diff --git a/burrow/src/wireguard/noise/timers.rs b/burrow/src/wireguard/noise/timers.rs new file mode 100755 index 0000000..1d0cf1f --- /dev/null +++ b/burrow/src/wireguard/noise/timers.rs @@ -0,0 +1,333 @@ +// Copyright (c) 2019 Cloudflare, Inc. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +use std::{ + mem, + ops::{Index, IndexMut}, + time::{Duration, Instant}, +}; + +use super::{errors::WireGuardError, TunnResult, Tunnel}; + +// Some constants, represent time in seconds +// https://www.wireguard.com/papers/wireguard.pdf#page=14 +pub(crate) const REKEY_AFTER_TIME: Duration = Duration::from_secs(120); +const REJECT_AFTER_TIME: Duration = Duration::from_secs(180); +const REKEY_ATTEMPT_TIME: Duration = Duration::from_secs(90); +pub(crate) const REKEY_TIMEOUT: Duration = Duration::from_secs(5); +const KEEPALIVE_TIMEOUT: Duration = Duration::from_secs(10); +const COOKIE_EXPIRATION_TIME: Duration = Duration::from_secs(120); + +#[derive(Debug)] +pub enum TimerName { + /// Current time, updated each call to `update_timers` + TimeCurrent, + /// Time when last handshake was completed + TimeSessionEstablished, + /// Time the last attempt for a new handshake began + TimeLastHandshakeStarted, + /// Time we last received and authenticated a packet + TimeLastPacketReceived, + /// Time we last send a packet + TimeLastPacketSent, + /// Time we last received and authenticated a DATA packet + TimeLastDataPacketReceived, + /// Time we last send a DATA packet + TimeLastDataPacketSent, + /// Time we last received a cookie + TimeCookieReceived, + /// Time we last sent persistent keepalive + TimePersistentKeepalive, + Top, +} + +use self::TimerName::*; + +#[derive(Debug)] +pub struct Timers { + /// Is the owner of the timer the initiator or the responder for the last + /// handshake? + is_initiator: bool, + /// Start time of the tunnel + time_started: Instant, + timers: [Duration; TimerName::Top as usize], + pub(super) session_timers: [Duration; super::N_SESSIONS], + /// Did we receive data without sending anything back? + want_keepalive: bool, + /// Did we send data without hearing back? + want_handshake: bool, + persistent_keepalive: usize, + /// Should this timer call reset rr function (if not a shared rr instance) + pub(super) should_reset_rr: bool, +} + +impl Timers { + pub(super) fn new(persistent_keepalive: Option, reset_rr: bool) -> Timers { + Timers { + is_initiator: false, + time_started: Instant::now(), + timers: Default::default(), + session_timers: Default::default(), + want_keepalive: Default::default(), + want_handshake: Default::default(), + persistent_keepalive: usize::from(persistent_keepalive.unwrap_or(0)), + should_reset_rr: reset_rr, + } + } + + fn is_initiator(&self) -> bool { + self.is_initiator + } + + // We don't really clear the timers, but we set them to the current time to + // so the reference time frame is the same + pub(super) fn clear(&mut self) { + let now = Instant::now().duration_since(self.time_started); + for t in &mut self.timers[..] { + *t = now; + } + self.want_handshake = false; + self.want_keepalive = false; + } +} + +impl Index for Timers { + type Output = Duration; + + fn index(&self, index: TimerName) -> &Duration { + &self.timers[index as usize] + } +} + +impl IndexMut for Timers { + fn index_mut(&mut self, index: TimerName) -> &mut Duration { + &mut self.timers[index as usize] + } +} + +impl Tunnel { + pub(super) fn timer_tick(&mut self, timer_name: TimerName) { + match timer_name { + TimeLastPacketReceived => { + self.timers.want_keepalive = true; + self.timers.want_handshake = false; + } + TimeLastPacketSent => { + self.timers.want_handshake = true; + self.timers.want_keepalive = false; + } + _ => {} + } + + let time = self.timers[TimeCurrent]; + self.timers[timer_name] = time; + } + + pub(super) fn timer_tick_session_established( + &mut self, + is_initiator: bool, + session_idx: usize, + ) { + self.timer_tick(TimeSessionEstablished); + self.timers.session_timers[session_idx % super::N_SESSIONS] = self.timers[TimeCurrent]; + self.timers.is_initiator = is_initiator; + } + + // We don't really clear the timers, but we set them to the current time to + // so the reference time frame is the same + fn clear_all(&mut self) { + for session in &mut self.sessions { + *session = None; + } + + self.packet_queue.clear(); + + self.timers.clear(); + } + + fn update_session_timers(&mut self, time_now: Duration) { + let timers = &mut self.timers; + + for (i, t) in timers.session_timers.iter_mut().enumerate() { + if time_now - *t > REJECT_AFTER_TIME { + if let Some(session) = self.sessions[i].take() { + tracing::debug!( + message = "SESSION_EXPIRED(REJECT_AFTER_TIME)", + session = session.receiving_index + ); + } + *t = time_now; + } + } + } + + pub fn update_timers<'a>(&mut self, dst: &'a mut [u8]) -> TunnResult<'a> { + let mut handshake_initiation_required = false; + let mut keepalive_required = false; + + let time = Instant::now(); + + if self.timers.should_reset_rr { + self.rate_limiter.reset_count(); + } + + // All the times are counted from tunnel initiation, for efficiency our timers + // are rounded to a second, as there is no real benefit to having highly + // accurate timers. + let now = time.duration_since(self.timers.time_started); + self.timers[TimeCurrent] = now; + + self.update_session_timers(now); + + // Load timers only once: + let session_established = self.timers[TimeSessionEstablished]; + let handshake_started = self.timers[TimeLastHandshakeStarted]; + let aut_packet_received = self.timers[TimeLastPacketReceived]; + let aut_packet_sent = self.timers[TimeLastPacketSent]; + let data_packet_received = self.timers[TimeLastDataPacketReceived]; + let data_packet_sent = self.timers[TimeLastDataPacketSent]; + let persistent_keepalive = self.timers.persistent_keepalive; + + { + if self.handshake.is_expired() { + return TunnResult::Err(WireGuardError::ConnectionExpired) + } + + // Clear cookie after COOKIE_EXPIRATION_TIME + if self.handshake.has_cookie() + && now - self.timers[TimeCookieReceived] >= COOKIE_EXPIRATION_TIME + { + self.handshake.clear_cookie(); + } + + // All ephemeral private keys and symmetric session keys are zeroed out after + // (REJECT_AFTER_TIME * 3) ms if no new keys have been exchanged. + if now - session_established >= REJECT_AFTER_TIME * 3 { + tracing::error!("CONNECTION_EXPIRED(REJECT_AFTER_TIME * 3)"); + self.handshake.set_expired(); + self.clear_all(); + return TunnResult::Err(WireGuardError::ConnectionExpired) + } + + if let Some(time_init_sent) = self.handshake.timer() { + // Handshake Initiation Retransmission + if now - handshake_started >= REKEY_ATTEMPT_TIME { + // After REKEY_ATTEMPT_TIME ms of trying to initiate a new handshake, + // the retries give up and cease, and clear all existing packets queued + // up to be sent. If a packet is explicitly queued up to be sent, then + // this timer is reset. + tracing::error!("CONNECTION_EXPIRED(REKEY_ATTEMPT_TIME)"); + self.handshake.set_expired(); + self.clear_all(); + return TunnResult::Err(WireGuardError::ConnectionExpired) + } + + if time_init_sent.elapsed() >= REKEY_TIMEOUT { + // We avoid using `time` here, because it can be earlier than `time_init_sent`. + // Once `checked_duration_since` is stable we can use that. + // A handshake initiation is retried after REKEY_TIMEOUT + jitter ms, + // if a response has not been received, where jitter is some random + // value between 0 and 333 ms. + tracing::warn!("HANDSHAKE(REKEY_TIMEOUT)"); + handshake_initiation_required = true; + } + } else { + if self.timers.is_initiator() { + // After sending a packet, if the sender was the original initiator + // of the handshake and if the current session key is REKEY_AFTER_TIME + // ms old, we initiate a new handshake. If the sender was the original + // responder of the handshake, it does not re-initiate a new handshake + // after REKEY_AFTER_TIME ms like the original initiator does. + if session_established < data_packet_sent + && now - session_established >= REKEY_AFTER_TIME + { + tracing::debug!("HANDSHAKE(REKEY_AFTER_TIME (on send))"); + handshake_initiation_required = true; + } + + // After receiving a packet, if the receiver was the original initiator + // of the handshake and if the current session key is REJECT_AFTER_TIME + // - KEEPALIVE_TIMEOUT - REKEY_TIMEOUT ms old, we initiate a new + // handshake. + if session_established < data_packet_received + && now - session_established + >= REJECT_AFTER_TIME - KEEPALIVE_TIMEOUT - REKEY_TIMEOUT + { + tracing::warn!( + "HANDSHAKE(REJECT_AFTER_TIME - KEEPALIVE_TIMEOUT - \ + REKEY_TIMEOUT \ + (on receive))" + ); + handshake_initiation_required = true; + } + } + + // If we have sent a packet to a given peer but have not received a + // packet after from that peer for (KEEPALIVE + REKEY_TIMEOUT) ms, + // we initiate a new handshake. + if data_packet_sent > aut_packet_received + && now - aut_packet_received >= KEEPALIVE_TIMEOUT + REKEY_TIMEOUT + && mem::replace(&mut self.timers.want_handshake, false) + { + tracing::warn!("HANDSHAKE(KEEPALIVE + REKEY_TIMEOUT)"); + handshake_initiation_required = true; + } + + if !handshake_initiation_required { + // If a packet has been received from a given peer, but we have not sent one + // back to the given peer in KEEPALIVE ms, we send an empty + // packet. + if data_packet_received > aut_packet_sent + && now - aut_packet_sent >= KEEPALIVE_TIMEOUT + && mem::replace(&mut self.timers.want_keepalive, false) + { + tracing::debug!("KEEPALIVE(KEEPALIVE_TIMEOUT)"); + keepalive_required = true; + } + + // Persistent KEEPALIVE + if persistent_keepalive > 0 + && (now - self.timers[TimePersistentKeepalive] + >= Duration::from_secs(persistent_keepalive as _)) + { + tracing::debug!("KEEPALIVE(PERSISTENT_KEEPALIVE)"); + self.timer_tick(TimePersistentKeepalive); + keepalive_required = true; + } + } + } + } + + if handshake_initiation_required { + return self.format_handshake_initiation(dst, true) + } + + if keepalive_required { + return self.encapsulate(&[], dst) + } + + TunnResult::Done + } + + pub fn time_since_last_handshake(&self) -> Option { + let current_session = self.current; + if self.sessions[current_session % super::N_SESSIONS].is_some() { + let duration_since_tun_start = Instant::now().duration_since(self.timers.time_started); + let duration_since_session_established = self.timers[TimeSessionEstablished]; + + Some(duration_since_tun_start - duration_since_session_established) + } else { + None + } + } + + pub fn persistent_keepalive(&self) -> Option { + let keepalive = self.timers.persistent_keepalive; + + if keepalive > 0 { + Some(keepalive as u16) + } else { + None + } + } +} diff --git a/burrow/src/wireguard/pcb.rs b/burrow/src/wireguard/pcb.rs new file mode 100755 index 0000000..9e12468 --- /dev/null +++ b/burrow/src/wireguard/pcb.rs @@ -0,0 +1,88 @@ +use std::net::SocketAddr; + +use anyhow::Error; +use fehler::throws; +use ip_network::IpNetwork; +use tokio::{net::UdpSocket, task::JoinHandle}; + +use super::{ + iface::PacketInterface, + noise::{TunnResult, Tunnel}, + Peer, +}; + +#[derive(Debug)] +pub struct PeerPcb { + pub endpoint: SocketAddr, + pub allowed_ips: Vec, + pub handle: Option>, + socket: Option, + tunnel: Tunnel, +} + +impl PeerPcb { + #[throws] + pub fn new(peer: Peer) -> Self { + let tunnel = Tunnel::new(peer.private_key, peer.public_key, None, None, 1, None) + .map_err(|s| anyhow::anyhow!("{}", s))?; + + Self { + endpoint: peer.endpoint, + allowed_ips: peer.allowed_ips, + handle: None, + socket: None, + tunnel, + } + } + + async fn open_if_closed(&mut self) -> Result<(), Error> { + if self.socket.is_none() { + let socket = UdpSocket::bind("0.0.0.0:0").await?; + socket.connect(self.endpoint).await?; + self.socket = Some(socket); + } + Ok(()) + } + + pub async fn run(&self, interface: Box<&dyn PacketInterface>) -> Result<(), Error> { + let mut buf = [0u8; 3000]; + loop { + let Some(socket) = self.socket.as_ref() else { + continue + }; + + let packet = match socket.recv(&mut buf).await { + Ok(s) => &buf[..s], + Err(e) => { + tracing::error!("eror receiving on peer socket: {}", e); + continue + } + }; + + let (len, addr) = socket.recv_from(&mut buf).await?; + + tracing::debug!("received {} bytes from {}", len, addr); + } + } + + pub async fn socket(&mut self) -> Result<&UdpSocket, Error> { + self.open_if_closed().await?; + Ok(self.socket.as_ref().expect("socket was just opened")) + } + + pub async fn send(&mut self, src: &[u8]) -> Result<(), Error> { + let mut dst_buf = [0u8; 3000]; + match self.tunnel.encapsulate(src, &mut dst_buf[..]) { + TunnResult::Done => {} + TunnResult::Err(e) => { + tracing::error!(message = "Encapsulate error", error = ?e) + } + TunnResult::WriteToNetwork(packet) => { + let socket = self.socket().await?; + socket.send(packet).await?; + } + _ => panic!("Unexpected result from encapsulate"), + }; + Ok(()) + } +} diff --git a/burrow/src/wireguard/peer.rs b/burrow/src/wireguard/peer.rs new file mode 100755 index 0000000..8a74ce1 --- /dev/null +++ b/burrow/src/wireguard/peer.rs @@ -0,0 +1,23 @@ +use std::{fmt, net::SocketAddr}; + +use anyhow::Error; +use fehler::throws; +use ip_network::IpNetwork; +use x25519_dalek::{PublicKey, StaticSecret}; + +pub struct Peer { + pub endpoint: SocketAddr, + pub private_key: StaticSecret, + pub public_key: PublicKey, + pub allowed_ips: Vec, +} + +impl fmt::Debug for Peer { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Peer") + .field("endpoint", &self.endpoint) + .field("public_key", &self.public_key) + .field("allowed_ips", &self.allowed_ips) + .finish() + } +} diff --git a/tun/build.rs b/tun/build.rs index 5569cc4..8da8a40 100644 --- a/tun/build.rs +++ b/tun/build.rs @@ -26,7 +26,7 @@ async fn generate(out_dir: &std::path::Path) -> anyhow::Result<()> { println!("cargo:rerun-if-changed={}", binary_path.to_str().unwrap()); if let (Ok(..), Ok(..)) = (File::open(&bindings_path), File::open(&binary_path)) { - return Ok(()); + return Ok(()) }; let archive = download(out_dir) @@ -80,9 +80,10 @@ async fn download(directory: &std::path::Path) -> anyhow::Result #[cfg(windows)] fn parse(file: std::fs::File) -> anyhow::Result<(bindgen::Bindings, Vec)> { - use anyhow::Context; use std::io::Read; + use anyhow::Context; + let reader = std::io::BufReader::new(file); let mut archive = zip::ZipArchive::new(reader)?; diff --git a/tun/src/lib.rs b/tun/src/lib.rs index 151c10d..a1ca636 100644 --- a/tun/src/lib.rs +++ b/tun/src/lib.rs @@ -2,11 +2,11 @@ #[cfg(target_os = "windows")] #[path = "windows/mod.rs"] -mod imp; +mod os_imp; #[cfg(any(target_os = "linux", target_vendor = "apple"))] #[path = "unix/mod.rs"] -pub(crate) mod imp; +pub(crate) mod os_imp; mod options; @@ -14,5 +14,5 @@ mod options; #[cfg(feature = "tokio")] pub mod tokio; -pub use imp::{TunInterface, TunQueue}; pub use options::TunOptions; +pub use os_imp::{TunInterface, TunQueue}; diff --git a/tun/src/options.rs b/tun/src/options.rs index e766be8..13493db 100644 --- a/tun/src/options.rs +++ b/tun/src/options.rs @@ -1,6 +1,7 @@ -use fehler::throws; use std::io::Error; +use fehler::throws; + use super::TunInterface; #[derive(Debug, Clone, Default)] @@ -15,25 +16,17 @@ pub struct TunOptions { } impl TunOptions { - pub fn new() -> Self { - Self::default() - } + pub fn new() -> Self { Self::default() } pub fn name(mut self, name: &str) -> Self { self.name = Some(name.to_owned()); self } - pub fn no_pi(mut self, enable: bool) { - self.no_pi = enable.then_some(()); - } + pub fn no_pi(mut self, enable: bool) { self.no_pi = enable.then_some(()); } - pub fn tun_excl(mut self, enable: bool) { - self.tun_excl = enable.then_some(()); - } + pub fn tun_excl(mut self, enable: bool) { self.tun_excl = enable.then_some(()); } #[throws] - pub fn open(self) -> TunInterface { - TunInterface::new_with_options(self)? - } + pub fn open(self) -> TunInterface { TunInterface::new_with_options(self)? } } diff --git a/tun/src/tokio/mod.rs b/tun/src/tokio/mod.rs index 7828279..8318830 100644 --- a/tun/src/tokio/mod.rs +++ b/tun/src/tokio/mod.rs @@ -1,4 +1,5 @@ use std::io; + use tokio::io::unix::AsyncFd; use tracing::instrument; @@ -16,7 +17,7 @@ impl TunInterface { } #[instrument] - pub async fn write(&self, buf: &[u8]) -> io::Result { + pub async fn send(&self, buf: &[u8]) -> io::Result { loop { let mut guard = self.inner.writable().await?; match guard.try_io(|inner| inner.get_ref().send(buf)) { @@ -27,7 +28,7 @@ impl TunInterface { } #[instrument] - pub async fn read(&mut self, buf: &mut [u8]) -> io::Result { + pub async fn recv(&mut self, buf: &mut [u8]) -> io::Result { loop { let mut guard = self.inner.readable_mut().await?; match guard.try_io(|inner| (*inner).get_mut().recv(buf)) { diff --git a/tun/src/unix/apple/kern_control.rs b/tun/src/unix/apple/kern_control.rs index abc1e04..76e576f 100644 --- a/tun/src/unix/apple/kern_control.rs +++ b/tun/src/unix/apple/kern_control.rs @@ -1,7 +1,6 @@ +use std::{io::Error, mem::size_of, os::unix::io::AsRawFd}; + use fehler::throws; -use std::io::Error; -use std::mem::size_of; -use std::os::unix::io::AsRawFd; use super::sys; @@ -16,10 +15,7 @@ pub trait SysControlSocket { impl SysControlSocket for socket2::Socket { #[throws] fn resolve(&self, name: &str, index: u32) -> socket2::SockAddr { - let mut info = sys::ctl_info { - ctl_id: 0, - ctl_name: [0; 96], - }; + let mut info = sys::ctl_info { ctl_id: 0, ctl_name: [0; 96] }; info.ctl_name[..name.len()].copy_from_slice(name.as_bytes()); unsafe { sys::resolve_ctl_info(self.as_raw_fd(), &mut info as *mut sys::ctl_info)? }; @@ -28,7 +24,7 @@ impl SysControlSocket for socket2::Socket { socket2::SockAddr::init(|addr_storage, len| { *len = size_of::() as u32; - let mut addr: &mut sys::sockaddr_ctl = &mut *addr_storage.cast(); + let addr: &mut sys::sockaddr_ctl = &mut *addr_storage.cast(); addr.sc_len = *len as u8; addr.sc_family = sys::AF_SYSTEM as u8; addr.ss_sysaddr = sys::AF_SYS_CONTROL as u16; diff --git a/tun/src/unix/apple/mod.rs b/tun/src/unix/apple/mod.rs index f4fd1e2..83dbdc1 100644 --- a/tun/src/unix/apple/mod.rs +++ b/tun/src/unix/apple/mod.rs @@ -1,22 +1,24 @@ +use std::{ + io::{Error, IoSlice}, + mem, + net::{Ipv4Addr, SocketAddrV4}, + os::fd::{AsRawFd, RawFd}, +}; + use byteorder::{ByteOrder, NetworkEndian}; use fehler::throws; use libc::{c_char, iovec, writev, AF_INET, AF_INET6}; -use tracing::info; use socket2::{Domain, SockAddr, Socket, Type}; -use std::io::IoSlice; -use std::net::{Ipv4Addr, SocketAddrV4}; -use std::os::fd::{AsRawFd, RawFd}; -use std::{io::Error, mem}; -use tracing::instrument; +use tracing::{self, instrument}; mod kern_control; mod sys; -pub use super::queue::TunQueue; - -use super::{ifname_to_string, string_to_ifname, TunOptions}; use kern_control::SysControlSocket; +pub use super::queue::TunQueue; +use super::{ifname_to_string, string_to_ifname, TunOptions}; + #[derive(Debug)] pub struct TunInterface { pub(crate) socket: socket2::Socket, @@ -81,7 +83,7 @@ impl TunInterface { let mut iff = self.ifreq()?; iff.ifr_ifru.ifru_addr = unsafe { *addr.as_ptr() }; self.perform(|fd| unsafe { sys::if_set_addr(fd, &iff) })?; - info!("ipv4_addr_set: {:?} (fd: {:?})", addr, self.as_raw_fd()) + tracing::info!("ipv4_addr_set: {:?} (fd: {:?})", addr, self.as_raw_fd()) } #[throws] @@ -118,7 +120,7 @@ impl TunInterface { let mut iff = self.ifreq()?; iff.ifr_ifru.ifru_mtu = mtu; self.perform(|fd| unsafe { sys::if_set_mtu(fd, &iff) })?; - info!("mtu_set: {:?} (fd: {:?})", mtu, self.as_raw_fd()) + tracing::info!("mtu_set: {:?} (fd: {:?})", mtu, self.as_raw_fd()) } #[throws] @@ -140,7 +142,7 @@ impl TunInterface { let mut iff = self.ifreq()?; iff.ifr_ifru.ifru_netmask = unsafe { *addr.as_ptr() }; self.perform(|fd| unsafe { sys::if_set_netmask(fd, &iff) })?; - info!( + tracing::info!( "netmask_set: {:?} (fd: {:?})", unsafe { iff.ifr_ifru.ifru_netmask }, self.as_raw_fd() diff --git a/tun/src/unix/apple/sys.rs b/tun/src/unix/apple/sys.rs index c0ea613..b4d4a6a 100644 --- a/tun/src/unix/apple/sys.rs +++ b/tun/src/unix/apple/sys.rs @@ -2,11 +2,20 @@ use std::mem; use libc::{c_char, c_int, c_short, c_uint, c_ulong, sockaddr}; pub use libc::{ - c_void, sockaddr_ctl, sockaddr_in, socklen_t, AF_SYSTEM, AF_SYS_CONTROL, IFNAMSIZ, + c_void, + sockaddr_ctl, + sockaddr_in, + socklen_t, + AF_SYSTEM, + AF_SYS_CONTROL, + IFNAMSIZ, SYSPROTO_CONTROL, }; use nix::{ - ioctl_read_bad, ioctl_readwrite, ioctl_write_ptr_bad, request_code_readwrite, + ioctl_read_bad, + ioctl_readwrite, + ioctl_write_ptr_bad, + request_code_readwrite, request_code_write, }; diff --git a/tun/src/unix/linux/mod.rs b/tun/src/unix/linux/mod.rs index 75bb9d2..90cf353 100644 --- a/tun/src/unix/linux/mod.rs +++ b/tun/src/unix/linux/mod.rs @@ -1,16 +1,18 @@ +use std::{ + fs::OpenOptions, + io::{Error, Write}, + mem, + net::{Ipv4Addr, Ipv6Addr, SocketAddrV4}, + os::{ + fd::RawFd, + unix::io::{AsRawFd, FromRawFd, IntoRawFd}, + }, +}; + use fehler::throws; - -use socket2::{Domain, SockAddr, Socket, Type}; -use std::fs::OpenOptions; -use std::io::{Error, Write}; -use std::mem; -use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4}; -use std::os::fd::RawFd; -use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd}; - -use tracing::{info, instrument}; - use libc::in6_ifreq; +use socket2::{Domain, SockAddr, Socket, Type}; +use tracing::{info, instrument}; use super::{ifname_to_string, string_to_ifname, TunOptions}; @@ -24,9 +26,7 @@ pub struct TunInterface { impl TunInterface { #[throws] #[instrument] - pub fn new() -> TunInterface { - Self::new_with_options(TunOptions::new())? - } + pub fn new() -> TunInterface { Self::new_with_options(TunOptions::new())? } #[throws] #[instrument] @@ -212,7 +212,5 @@ impl TunInterface { #[throws] #[instrument] - pub fn send(&self, buf: &[u8]) -> usize { - self.socket.send(buf)? - } + pub fn send(&self, buf: &[u8]) -> usize { self.socket.send(buf)? } } diff --git a/tun/src/unix/linux/sys.rs b/tun/src/unix/linux/sys.rs index 8d8725b..e12c8ec 100644 --- a/tun/src/unix/linux/sys.rs +++ b/tun/src/unix/linux/sys.rs @@ -1,10 +1,7 @@ -use nix::{ioctl_read_bad, ioctl_write_ptr_bad, request_code_read, request_code_write}; use std::mem::size_of; -pub use libc::ifreq; -pub use libc::sockaddr; -pub use libc::sockaddr_in; -pub use libc::sockaddr_in6; +pub use libc::{ifreq, sockaddr, sockaddr_in, sockaddr_in6}; +use nix::{ioctl_read_bad, ioctl_write_ptr_bad, request_code_read, request_code_write}; ioctl_write_ptr_bad!( tun_set_iff, diff --git a/tun/src/unix/mod.rs b/tun/src/unix/mod.rs index 9da4204..269f4e6 100644 --- a/tun/src/unix/mod.rs +++ b/tun/src/unix/mod.rs @@ -2,6 +2,7 @@ use std::{ io::{Error, Read}, os::fd::{AsRawFd, FromRawFd, IntoRawFd, RawFd}, }; + use tracing::instrument; use super::TunOptions; @@ -28,9 +29,8 @@ impl AsRawFd for TunInterface { impl FromRawFd for TunInterface { unsafe fn from_raw_fd(fd: RawFd) -> TunInterface { - TunInterface { - socket: socket2::Socket::from_raw_fd(fd), - } + let socket = socket2::Socket::from_raw_fd(fd); + TunInterface { socket } } } @@ -65,4 +65,4 @@ pub fn string_to_ifname(name: &str) -> [libc::c_char; libc::IFNAMSIZ] { let len = name.len().min(buf.len()); buf[..len].copy_from_slice(unsafe { &*(name.as_bytes() as *const _ as *const [libc::c_char]) }); buf -} \ No newline at end of file +} diff --git a/tun/src/unix/queue.rs b/tun/src/unix/queue.rs index 923f926..879dcd5 100644 --- a/tun/src/unix/queue.rs +++ b/tun/src/unix/queue.rs @@ -1,10 +1,10 @@ -use fehler::throws; - use std::{ io::{Error, Read, Write}, mem::MaybeUninit, os::unix::io::{AsRawFd, IntoRawFd, RawFd}, }; + +use fehler::throws; use tracing::instrument; use crate::TunInterface; @@ -15,10 +15,9 @@ pub struct TunQueue { } impl TunQueue { - #[throws] #[instrument] - pub fn recv(&self, buf: &mut [MaybeUninit]) -> usize { - self.socket.recv(buf)? + pub fn recv(&self, buf: &mut [MaybeUninit]) -> Result { + self.socket.recv(buf) } } @@ -43,9 +42,7 @@ impl Write for TunQueue { impl From for TunQueue { fn from(interface: TunInterface) -> TunQueue { - TunQueue { - socket: interface.socket, - } + TunQueue { socket: interface.socket } } } diff --git a/tun/src/windows/mod.rs b/tun/src/windows/mod.rs index bae75c0..9b6d5ad 100644 --- a/tun/src/windows/mod.rs +++ b/tun/src/windows/mod.rs @@ -1,15 +1,14 @@ -use std::fmt::Debug; +use std::{fmt::Debug, io::Error, ptr}; + use fehler::throws; -use std::io::Error; -use std::ptr; use widestring::U16CString; use windows::Win32::Foundation::GetLastError; mod queue; -use super::TunOptions; - pub use queue::TunQueue; +use super::TunOptions; + pub struct TunInterface { handle: sys::WINTUN_ADAPTER_HANDLE, name: String, @@ -26,9 +25,7 @@ impl Debug for TunInterface { impl TunInterface { #[throws] - pub fn new() -> TunInterface { - Self::new_with_options(TunOptions::new())? - } + pub fn new() -> TunInterface { Self::new_with_options(TunOptions::new())? } #[throws] pub(crate) fn new_with_options(options: TunOptions) -> TunInterface { @@ -46,15 +43,11 @@ impl TunInterface { } } - pub fn name(&self) -> String { - self.name.clone() - } + pub fn name(&self) -> String { self.name.clone() } } impl Drop for TunInterface { - fn drop(&mut self) { - unsafe { sys::WINTUN.WintunCloseAdapter(self.handle) } - } + fn drop(&mut self) { unsafe { sys::WINTUN.WintunCloseAdapter(self.handle) } } } pub(crate) mod sys { diff --git a/tun/tests/configure.rs b/tun/tests/configure.rs index 0f1199d..6ef597b 100644 --- a/tun/tests/configure.rs +++ b/tun/tests/configure.rs @@ -1,13 +1,11 @@ +use std::{io::Error, net::Ipv4Addr}; + use fehler::throws; -use std::io::Error; -use std::net::Ipv4Addr; use tun::TunInterface; #[test] #[throws] -fn test_create() { - TunInterface::new()?; -} +fn test_create() { TunInterface::new()?; } #[test] #[throws] diff --git a/tun/tests/packets.rs b/tun/tests/packets.rs index b160893..91ebfba 100644 --- a/tun/tests/packets.rs +++ b/tun/tests/packets.rs @@ -1,7 +1,6 @@ -use fehler::throws; -use std::io::Error; +use std::{io::Error, net::Ipv4Addr}; -use std::net::Ipv4Addr; +use fehler::throws; use tun::TunInterface; #[throws] @@ -9,8 +8,8 @@ use tun::TunInterface; #[ignore = "requires interactivity"] #[cfg(not(target_os = "windows"))] fn tst_read() { - // This test is interactive, you need to send a packet to any server through 192.168.1.10 - // EG. `sudo route add 8.8.8.8 192.168.1.10`, + // This test is interactive, you need to send a packet to any server through + // 192.168.1.10 EG. `sudo route add 8.8.8.8 192.168.1.10`, //`dig @8.8.8.8 hackclub.com` let mut tun = TunInterface::new()?; println!("tun name: {:?}", tun.name()?); diff --git a/tun/tests/tokio.rs b/tun/tests/tokio.rs index e745c27..f7cb273 100644 --- a/tun/tests/tokio.rs +++ b/tun/tests/tokio.rs @@ -4,7 +4,7 @@ use std::net::Ipv4Addr; #[cfg(all(feature = "tokio", not(target_os = "windows")))] async fn test_create() { let tun = tun::TunInterface::new().unwrap(); - let async_tun = tun::tokio::TunInterface::new(tun).unwrap(); + let _ = tun::tokio::TunInterface::new(tun).unwrap(); } #[tokio::test] @@ -17,6 +17,6 @@ async fn test_write() { let async_tun = tun::tokio::TunInterface::new(tun).unwrap(); let mut buf = [0u8; 1500]; buf[0] = 6 << 4; - let bytes_written = async_tun.write(&buf).await.unwrap(); + let bytes_written = async_tun.send(&buf).await.unwrap(); assert!(bytes_written > 0); }