Squashed 'components/mycelium/' content from commit afb32e0

git-subtree-dir: components/mycelium
git-subtree-split: afb32e0cdb2d4cdd17f22a5693278068d061f08c
This commit is contained in:
2025-08-16 21:12:34 +02:00
commit 10025f9fa5
132 changed files with 50951 additions and 0 deletions

79
mycelium/Cargo.toml Normal file
View File

@@ -0,0 +1,79 @@
# Cargo manifest for the mycelium node.
[package]
name = "mycelium"
version = "0.6.1"
edition = "2021"
license-file = "../LICENSE"
readme = "../README.md"
# Optional functionality; none of these features are enabled by default.
[features]
# Message subsystem support (no additional dependencies).
message = []
# Private network overlays; pulls in openssl for the TLS layer.
private-network = ["dep:openssl", "dep:tokio-openssl"]
# Build and statically link a vendored OpenSSL instead of using the system one.
vendored-openssl = ["openssl/vendored"]
mactunfd = [
    "tun/appstore",
] #mactunfd is a flag to specify that macos should provide tun FD instead of tun name
[dependencies]
cdn-meta = { git = "https://github.com/threefoldtech/mycelium-cdn-registry", package = "cdn-meta" }
tokio = { version = "1.46.1", features = [
    "io-util",
    "fs",
    "macros",
    "net",
    "sync",
    "time",
    "rt-multi-thread", # FIXME: remove once tokio::task::block_in_place calls are resolved
] }
tokio-util = { version = "0.7.15", features = ["codec"] }
futures = "0.3.31"
serde = { version = "1.0.219", features = ["derive"] }
rand = "0.9.1"
bytes = "1.10.1"
x25519-dalek = { version = "2.0.1", features = ["getrandom", "static_secrets"] }
aes-gcm = "0.10.3"
tracing = { version = "0.1.41", features = ["release_max_level_debug"] }
tracing-subscriber = { version = "0.3.19", features = ["env-filter"] }
tracing-logfmt = { version = "0.3.5", features = ["ansi_logs"] }
faster-hex = "0.10.0"
tokio-stream = { version = "0.1.17", features = ["sync"] }
left-right = "0.11.5"
ipnet = "2.11.0"
ip_network_table-deps-treebitmap = "0.5.0"
blake3 = "1.8.2"
etherparse = "0.18.0"
quinn = { version = "0.11.8", default-features = false, features = [
    "runtime-tokio",
    "rustls",
] }
rustls = { version = "0.23.29", default-features = false, features = ["ring"] }
rcgen = "0.14.2"
netdev = "0.36.0"
openssl = { version = "0.10.73", optional = true }
tokio-openssl = { version = "0.6.5", optional = true }
arc-swap = "1.7.1"
dashmap = { version = "6.1.0", features = ["inline"] }
ahash = "0.8.11"
axum = "0.8.4"
axum-extra = "0.10.1"
reqwest = "0.12.22"
redis = { version = "0.32.4", features = ["tokio-comp"] }
reed-solomon-erasure = "6.0.0"
# Platform specific dependencies, mainly for TUN device and routing integration.
[target.'cfg(target_os = "linux")'.dependencies]
rtnetlink = "0.17.0"
tokio-tun = "0.13.2"
nix = { version = "0.30.1", features = ["socket"] }
[target.'cfg(target_os = "macos")'.dependencies]
tun = { git = "https://github.com/LeeSmet/rust-tun", features = ["async"] }
libc = "0.2.174"
# NOTE(review): this pins nix 0.29.x while the linux target uses 0.30.1 — confirm the
# version divergence between platforms is intentional.
nix = { version = "0.29.0", features = ["net", "socket", "ioctl"] }
[target.'cfg(target_os = "windows")'.dependencies]
wintun = "0.5.1"
[target.'cfg(target_os = "android")'.dependencies]
tun = { git = "https://github.com/LeeSmet/rust-tun", features = ["async"] }
[target.'cfg(target_os = "ios")'.dependencies]
tun = { git = "https://github.com/LeeSmet/rust-tun", features = ["async"] }

321
mycelium/src/babel.rs Normal file
View File

@@ -0,0 +1,321 @@
//! This module contains babel related structs.
//!
//! We don't fully implement the babel spec, and items which are implemented might deviate to fit
//! our specific use case. For reference, the implementation is based on [this
//! RFC](https://datatracker.ietf.org/doc/html/rfc8966).
use std::io;
use bytes::{Buf, BufMut};
use tokio_util::codec::{Decoder, Encoder};
use tracing::trace;
pub use self::{
hello::Hello, ihu::Ihu, route_request::RouteRequest, seqno_request::SeqNoRequest,
update::Update,
};
pub use self::tlv::Tlv;
mod hello;
mod ihu;
mod route_request;
mod seqno_request;
mod tlv;
mod update;
/// Magic byte to identify babel protocol packet.
const BABEL_MAGIC: u8 = 42;
/// The version of the protocol we are currently using.
// NOTE(review): RFC 8966 specifies protocol version 2; this implementation deliberately
// deviates from the spec (see module docs), so peers must run a compatible implementation.
const BABEL_VERSION: u8 = 3;
/// Size of a babel header on the wire.
const HEADER_WIRE_SIZE: usize = 4;
/// TLV type for the [`Hello`] tlv
const TLV_TYPE_HELLO: u8 = 4;
/// TLV type for the [`Ihu`] tlv
const TLV_TYPE_IHU: u8 = 5;
/// TLV type for the [`Update`] tlv
const TLV_TYPE_UPDATE: u8 = 8;
/// TLV type for the [`RouteRequest`] tlv
const TLV_TYPE_ROUTE_REQUEST: u8 = 9;
/// TLV type for the [`SeqNoRequest`] tlv
const TLV_TYPE_SEQNO_REQUEST: u8 = 10;
/// Wildcard address, the value is empty (0 bytes length).
const AE_WILDCARD: u8 = 0;
/// IPv4 address, the value is _at most_ 4 bytes long.
const AE_IPV4: u8 = 1;
/// IPv6 address, the value is _at most_ 16 bytes long.
const AE_IPV6: u8 = 2;
/// Link-local IPv6 address, the value is 8 bytes long. This implies a `fe80::/64` prefix.
const AE_IPV6_LL: u8 = 3;
/// A codec which can send and receive whole babel packets on the wire.
#[derive(Debug, Clone)]
pub struct Codec {
    /// Header of a partially received packet, stored between `decode` calls when the
    /// full body has not yet arrived on the wire.
    header: Option<Header>,
}

impl Codec {
    /// Create a new `BabelCodec`.
    pub fn new() -> Self {
        Self { header: None }
    }

    /// Resets the `BabelCodec` to its default state, discarding any partially decoded
    /// packet header.
    pub fn reset(&mut self) {
        self.header = None;
    }
}

// A public type with an argument-less `new` should also implement `Default`
// (clippy::new_without_default); this also allows `Framed::new(io, Codec::default())`.
impl Default for Codec {
    fn default() -> Self {
        Self::new()
    }
}
/// The header for a babel packet. This follows the definition of the header [in the
/// RFC](https://datatracker.ietf.org/doc/html/rfc8966#name-packet-format). Since the header
/// contains only hard-coded fields and the length of an encoded body, there is no need for users
/// to manually construct this. In fact, it exists only to make our lives slightly easier in
/// reading/writing the header on the wire.
#[derive(Debug, Clone)]
struct Header {
    /// Magic byte; always [`BABEL_MAGIC`] on packets we accept.
    magic: u8,
    /// Protocol version; always [`BABEL_VERSION`] on packets we accept.
    version: u8,
    /// This is the length of the whole body following this header. Also excludes any possible
    /// trailers.
    body_length: u16,
}
impl Decoder for Codec {
    type Item = Tlv;
    type Error = io::Error;

    /// Decode a single TLV from the buffer, or return `Ok(None)` when more data is
    /// needed. A parsed header is cached on `self` so decoding can resume once the
    /// rest of the body arrives.
    fn decode(&mut self, src: &mut bytes::BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        // Read a header if we don't have one yet. A stored header means a previous call
        // already parsed it but the body had not fully arrived at that point.
        let header = if let Some(header) = self.header.take() {
            trace!("Continue from stored header");
            header
        } else {
            if src.remaining() < HEADER_WIRE_SIZE {
                trace!("Insufficient bytes to read a babel header");
                return Ok(None);
            }
            trace!("Read babel header");
            Header {
                magic: src.get_u8(),
                version: src.get_u8(),
                body_length: src.get_u16(),
            }
        };
        if src.remaining() < header.body_length as usize {
            trace!("Insufficient bytes to read babel body");
            // Stash the header so the next call resumes here instead of re-parsing.
            self.header = Some(header);
            return Ok(None);
        }
        // Silently ignore packets which don't have the correct values set, as defined in the
        // spec. Note that we consume the amount of bytes identified so we leave the parser in the
        // correct state for the next packet.
        if header.magic != BABEL_MAGIC || header.version != BABEL_VERSION {
            trace!("Dropping babel packet with wrong magic or version");
            src.advance(header.body_length as usize);
            self.reset();
            return Ok(None);
        }
        // At this point we have a whole body loaded in the buffer. We currently don't support sub
        // TLV's.
        trace!("Read babel TLV body");
        // TODO: Technically we need to loop here as we can have multiple TLVs.
        // TLV header
        let tlv_type = src.get_u8();
        let body_len = src.get_u8();
        // TLV payload. Each `from_bytes` is expected to consume exactly `body_len` bytes,
        // even when it rejects the TLV and returns `None`.
        let tlv = match tlv_type {
            TLV_TYPE_HELLO => Some(Hello::from_bytes(src).into()),
            TLV_TYPE_IHU => Ihu::from_bytes(src, body_len).map(From::from),
            TLV_TYPE_UPDATE => Update::from_bytes(src, body_len).map(From::from),
            TLV_TYPE_ROUTE_REQUEST => RouteRequest::from_bytes(src, body_len).map(From::from),
            TLV_TYPE_SEQNO_REQUEST => SeqNoRequest::from_bytes(src, body_len).map(From::from),
            _ => {
                // Unrecognized body type, silently drop. The 2 TLV header bytes were
                // already consumed above, so only skip the remainder of the body.
                trace!("Dropping unrecognized tlv");
                src.advance(header.body_length as usize - 2);
                self.reset();
                return Ok(None);
            }
        };
        Ok(tlv)
    }
}
impl Encoder<Tlv> for Codec {
    type Error = io::Error;

    /// Encode a single TLV as a complete babel packet: packet header, TLV header,
    /// then the TLV payload.
    fn encode(&mut self, item: Tlv, dst: &mut bytes::BytesMut) -> Result<(), Self::Error> {
        // Packet header: magic, version, and the length of the body. The body is a
        // single TLV header (2 bytes) plus its payload.
        dst.put_u8(BABEL_MAGIC);
        dst.put_u8(BABEL_VERSION);
        dst.put_u16(item.wire_size() as u16 + 2);
        // TLV header: type byte followed by payload length.
        // TODO: currently only 1 TLV per body.
        let tlv_type = match item {
            Tlv::Hello(_) => TLV_TYPE_HELLO,
            Tlv::Ihu(_) => TLV_TYPE_IHU,
            Tlv::Update(_) => TLV_TYPE_UPDATE,
            Tlv::RouteRequest(_) => TLV_TYPE_ROUTE_REQUEST,
            Tlv::SeqNoRequest(_) => TLV_TYPE_SEQNO_REQUEST,
        };
        dst.put_u8(tlv_type);
        dst.put_u8(item.wire_size());
        // TLV payload.
        item.write_bytes(dst);
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use std::{net::Ipv6Addr, time::Duration};
    use futures::{SinkExt, StreamExt};
    use tokio_util::codec::Framed;
    use crate::subnet::Subnet;

    // Each test round-trips one TLV through a pair of codecs connected by an in-memory
    // duplex stream, and asserts the decoded value equals the one sent.

    // Round-trip a unicast Hello TLV.
    #[tokio::test]
    async fn codec_hello() {
        let (tx, rx) = tokio::io::duplex(1024);
        let mut sender = Framed::new(tx, super::Codec::new());
        let mut receiver = Framed::new(rx, super::Codec::new());
        let hello = super::Hello::new_unicast(15.into(), 400);
        sender
            .send(hello.clone().into())
            .await
            .expect("Send on a non-networked buffer can never fail; qed");
        let recv_hello = receiver
            .next()
            .await
            .expect("Buffer isn't closed so this is always `Some`; qed")
            .expect("Can decode the previously encoded value");
        assert_eq!(super::Tlv::from(hello), recv_hello);
    }

    // Round-trip an IHU TLV using the wildcard address encoding.
    #[tokio::test]
    async fn codec_ihu() {
        let (tx, rx) = tokio::io::duplex(1024);
        let mut sender = Framed::new(tx, super::Codec::new());
        let mut receiver = Framed::new(rx, super::Codec::new());
        let ihu = super::Ihu::new(27.into(), 400, None);
        sender
            .send(ihu.clone().into())
            .await
            .expect("Send on a non-networked buffer can never fail; qed");
        let recv_ihu = receiver
            .next()
            .await
            .expect("Buffer isn't closed so this is always `Some`; qed")
            .expect("Can decode the previously encoded value");
        assert_eq!(super::Tlv::from(ihu), recv_ihu);
    }

    // Round-trip an Update TLV carrying a /64 IPv6 subnet and a 40-byte router id.
    #[tokio::test]
    async fn codec_update() {
        let (tx, rx) = tokio::io::duplex(1024);
        let mut sender = Framed::new(tx, super::Codec::new());
        let mut receiver = Framed::new(rx, super::Codec::new());
        let update = super::Update::new(
            Duration::from_secs(400),
            16.into(),
            25.into(),
            Subnet::new(Ipv6Addr::new(0x400, 1, 2, 3, 0, 0, 0, 0).into(), 64)
                .expect("64 is a valid IPv6 prefix size; qed"),
            [
                1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
                24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40,
            ]
            .into(),
        );
        sender
            .send(update.clone().into())
            .await
            .expect("Send on a non-networked buffer can never fail; qed");
        println!("Sent update packet");
        let recv_update = receiver
            .next()
            .await
            .expect("Buffer isn't closed so this is always `Some`; qed")
            .expect("Can decode the previously encoded value");
        println!("Received update packet");
        assert_eq!(super::Tlv::from(update), recv_update);
    }

    // Round-trip a SeqNo request TLV.
    #[tokio::test]
    async fn codec_seqno_request() {
        let (tx, rx) = tokio::io::duplex(1024);
        let mut sender = Framed::new(tx, super::Codec::new());
        let mut receiver = Framed::new(rx, super::Codec::new());
        let snr = super::SeqNoRequest::new(
            16.into(),
            [
                1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
                24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40,
            ]
            .into(),
            Subnet::new(Ipv6Addr::new(0x400, 1, 2, 3, 0, 0, 0, 0).into(), 64)
                .expect("64 is a valid IPv6 prefix size; qed"),
        );
        sender
            .send(snr.clone().into())
            .await
            .expect("Send on a non-networked buffer can never fail; qed");
        let recv_update = receiver
            .next()
            .await
            .expect("Buffer isn't closed so this is always `Some`; qed")
            .expect("Can decode the previously encoded value");
        assert_eq!(super::Tlv::from(snr), recv_update);
    }

    // Round-trip a route request TLV for a specific /64 prefix.
    #[tokio::test]
    async fn codec_route_request() {
        let (tx, rx) = tokio::io::duplex(1024);
        let mut sender = Framed::new(tx, super::Codec::new());
        let mut receiver = Framed::new(rx, super::Codec::new());
        let rr = super::RouteRequest::new(
            Some(
                Subnet::new(Ipv6Addr::new(0x400, 1, 2, 3, 0, 0, 0, 0).into(), 64)
                    .expect("64 is a valid IPv6 prefix size; qed"),
            ),
            13,
        );
        sender
            .send(rr.clone().into())
            .await
            .expect("Send on a non-networked buffer can never fail; qed");
        let recv_update = receiver
            .next()
            .await
            .expect("Buffer isn't closed so this is always `Some`; qed")
            .expect("Can decode the previously encoded value");
        assert_eq!(super::Tlv::from(rr), recv_update);
    }
}

162
mycelium/src/babel/hello.rs Normal file
View File

@@ -0,0 +1,162 @@
//! The babel [Hello TLV](https://datatracker.ietf.org/doc/html/rfc8966#section-4.6.5).
use bytes::{Buf, BufMut};
use tracing::trace;
use crate::sequence_number::SeqNo;
/// Flag bit indicating a [`Hello`] is sent as unicast hello.
const HELLO_FLAG_UNICAST: u16 = 0x8000;
/// Mask to apply to [`Hello`] flags, leaving only valid flags.
const FLAG_MASK: u16 = 0b10000000_00000000;
/// Wire size of a [`Hello`] TLV without TLV header.
const HELLO_WIRE_SIZE: u8 = 6;
/// Hello TLV body as defined in https://datatracker.ietf.org/doc/html/rfc8966#section-4.6.5.
#[derive(Debug, Clone, PartialEq)]
pub struct Hello {
    /// Flag bits; only [`HELLO_FLAG_UNICAST`] is currently defined, all other bits are
    /// masked off when decoding.
    flags: u16,
    /// Sequence number carried by this hello.
    seqno: SeqNo,
    /// Advertised interval until the next hello. Unit is not enforced here — presumably
    /// centiseconds as per RFC 8966; confirm at call sites.
    interval: u16,
}
impl Hello {
    /// Create a new unicast hello packet with the given sequence number and advertised
    /// interval.
    pub fn new_unicast(seqno: SeqNo, interval: u16) -> Self {
        Self {
            flags: HELLO_FLAG_UNICAST,
            seqno,
            interval,
        }
    }

    /// Calculates the size on the wire of this `Hello`.
    pub fn wire_size(&self) -> u8 {
        HELLO_WIRE_SIZE
    }

    /// Construct a `Hello` from wire bytes.
    ///
    /// # Panics
    ///
    /// This function will panic if there are insufficient bytes present in the provided buffer to
    /// decode a complete `Hello`.
    pub fn from_bytes(src: &mut bytes::BytesMut) -> Self {
        // Fields are read in wire order: flags, seqno, interval (big endian). Undefined
        // flag bits are cleared on read since only the bits in FLAG_MASK carry meaning.
        let hello = Self {
            flags: src.get_u16() & FLAG_MASK,
            seqno: src.get_u16().into(),
            interval: src.get_u16(),
        };
        trace!("Read hello tlv body");
        hello
    }

    /// Encode this `Hello` tlv as part of a packet.
    pub fn write_bytes(&self, dst: &mut bytes::BytesMut) {
        // Write the three 16-bit fields in wire order.
        for word in [self.flags, self.seqno.into(), self.interval] {
            dst.put_u16(word);
        }
    }
}
#[cfg(test)]
mod tests {
    use bytes::Buf;

    // Verify the on-wire layout: 2 flag bytes, 2 seqno bytes, 2 interval bytes, big endian.
    #[test]
    fn encoding() {
        let mut buf = bytes::BytesMut::new();
        let hello = super::Hello {
            flags: 0,
            seqno: 25.into(),
            interval: 400,
        };
        hello.write_bytes(&mut buf);
        assert_eq!(buf.len(), 6);
        assert_eq!(buf[..6], [0, 0, 0, 25, 1, 144]);
        let mut buf = bytes::BytesMut::new();
        let hello = super::Hello {
            flags: super::HELLO_FLAG_UNICAST,
            seqno: 16.into(),
            interval: 4000,
        };
        hello.write_bytes(&mut buf);
        assert_eq!(buf.len(), 6);
        assert_eq!(buf[..6], [128, 0, 0, 16, 15, 160]);
    }

    // Decode known byte patterns and verify all fields and full buffer consumption.
    #[test]
    fn decoding() {
        let mut buf = bytes::BytesMut::from(&[0b10000000u8, 0b00000000, 0, 19, 2, 1][..]);
        let hello = super::Hello {
            flags: super::HELLO_FLAG_UNICAST,
            seqno: 19.into(),
            interval: 513,
        };
        assert_eq!(super::Hello::from_bytes(&mut buf), hello);
        assert_eq!(buf.remaining(), 0);
        let mut buf = bytes::BytesMut::from(&[0b00000000u8, 0b00000000, 1, 19, 200, 100][..]);
        let hello = super::Hello {
            flags: 0,
            seqno: 275.into(),
            interval: 51300,
        };
        assert_eq!(super::Hello::from_bytes(&mut buf), hello);
        assert_eq!(buf.remaining(), 0);
    }

    // Undefined flag bits must be cleared by the decoder; only the top bit is defined.
    #[test]
    fn decode_ignores_invalid_flag_bits() {
        let mut buf = bytes::BytesMut::from(&[0b10001001u8, 0b00000000, 0, 100, 1, 144][..]);
        let hello = super::Hello {
            flags: super::HELLO_FLAG_UNICAST,
            seqno: 100.into(),
            interval: 400,
        };
        assert_eq!(super::Hello::from_bytes(&mut buf), hello);
        assert_eq!(buf.remaining(), 0);
        let mut buf = bytes::BytesMut::from(&[0b00001001u8, 0b00000000, 0, 100, 1, 144][..]);
        let hello = super::Hello {
            flags: 0,
            seqno: 100.into(),
            interval: 400,
        };
        assert_eq!(super::Hello::from_bytes(&mut buf), hello);
        assert_eq!(buf.remaining(), 0);
    }

    // Encoding followed by decoding yields the original value and consumes every byte.
    #[test]
    fn roundtrip() {
        let mut buf = bytes::BytesMut::new();
        let hello_src = super::Hello::new_unicast(16.into(), 400);
        hello_src.write_bytes(&mut buf);
        let decoded = super::Hello::from_bytes(&mut buf);
        assert_eq!(hello_src, decoded);
        assert_eq!(buf.remaining(), 0);
    }
}

246
mycelium/src/babel/ihu.rs Normal file
View File

@@ -0,0 +1,246 @@
//! The babel [IHU TLV](https://datatracker.ietf.org/doc/html/rfc8966#name-ihu).
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use bytes::{Buf, BufMut};
use tracing::trace;
use crate::metric::Metric;
use super::{AE_IPV4, AE_IPV6, AE_IPV6_LL, AE_WILDCARD};
/// Base wire size of an [`Ihu`] without variable length address encoding.
const IHU_BASE_WIRE_SIZE: u8 = 6;
/// IHU TLV body as defined in https://datatracker.ietf.org/doc/html/rfc8966#name-ihu.
#[derive(Debug, Clone, PartialEq)]
pub struct Ihu {
    /// Receive cost reported for the neighbour this IHU applies to.
    rx_cost: Metric,
    /// Interval between successive IHUs. MUST NOT be 0; see [`Ihu::new`].
    interval: u16,
    /// Address of the neighbour this IHU applies to; `None` encodes the wildcard address.
    address: Option<IpAddr>,
}
impl Ihu {
    /// Create a new `Ihu` to be transmitted.
    ///
    /// # Panics
    ///
    /// This function will panic if `interval` is 0, which is illegal per the RFC (the
    /// receiver uses this value to calculate the hold time).
    pub fn new(rx_cost: Metric, interval: u16, address: Option<IpAddr>) -> Self {
        // An interval of 0 is illegal according to the RFC, as this value is used by the receiver
        // to calculate the hold time.
        if interval == 0 {
            panic!("Ihu interval MUST NOT be 0");
        }
        Self {
            rx_cost,
            interval,
            address,
        }
    }
    /// Calculates the size on the wire of this `Ihu`.
    pub fn wire_size(&self) -> u8 {
        IHU_BASE_WIRE_SIZE
            + match self.address {
                None => 0,
                Some(IpAddr::V4(_)) => 4,
                // TODO: link local should be encoded differently
                Some(IpAddr::V6(_)) => 16,
            }
    }
    /// Construct a `Ihu` from wire bytes.
    ///
    /// Returns `None` when the address encoding is unknown; in that case the remaining
    /// bytes of the TLV are still consumed so the caller's buffer stays aligned.
    ///
    /// # Panics
    ///
    /// This function will panic if there are insufficient bytes present in the provided buffer to
    /// decode a complete `Ihu`.
    pub fn from_bytes(src: &mut bytes::BytesMut, len: u8) -> Option<Self> {
        let ae = src.get_u8();
        // read and ignore reserved byte
        let _ = src.get_u8();
        let rx_cost = src.get_u16().into();
        let interval = src.get_u16();
        let address = match ae {
            AE_WILDCARD => None,
            AE_IPV4 => {
                let mut raw_ip = [0; 4];
                raw_ip.copy_from_slice(&src[..4]);
                src.advance(4);
                Some(Ipv4Addr::from(raw_ip).into())
            }
            AE_IPV6 => {
                let mut raw_ip = [0; 16];
                raw_ip.copy_from_slice(&src[..16]);
                src.advance(16);
                Some(Ipv6Addr::from(raw_ip).into())
            }
            AE_IPV6_LL => {
                // Only the 8-byte interface identifier is carried on the wire; the
                // fe80::/64 prefix is implied by this address encoding.
                let mut raw_ip = [0; 16];
                raw_ip[0] = 0xfe;
                raw_ip[1] = 0x80;
                raw_ip[8..].copy_from_slice(&src[..8]);
                src.advance(8);
                Some(Ipv6Addr::from(raw_ip).into())
            }
            _ => {
                // Invalid AE type, skip remaining data and ignore. 6 bytes (ae, reserved,
                // rx_cost, interval) were already consumed above.
                trace!("Invalid AE type in IHU TLV, drop TLV");
                src.advance(len as usize - 6);
                return None;
            }
        };
        trace!("Read ihu tlv body");
        Some(Self {
            rx_cost,
            interval,
            address,
        })
    }
    /// Encode this `Ihu` tlv as part of a packet.
    pub fn write_bytes(&self, dst: &mut bytes::BytesMut) {
        dst.put_u8(match self.address {
            None => AE_WILDCARD,
            Some(IpAddr::V4(_)) => AE_IPV4,
            // NOTE(review): link-local addresses are written using AE_IPV6 (full 16 bytes),
            // never AE_IPV6_LL — asymmetric with `from_bytes`; confirm this is intended
            // (see the TODO in `wire_size`).
            Some(IpAddr::V6(_)) => AE_IPV6,
        });
        // reserved byte, must be all 0
        dst.put_u8(0);
        dst.put_u16(self.rx_cost.into());
        dst.put_u16(self.interval);
        match self.address {
            None => {}
            Some(IpAddr::V4(ip)) => dst.put_slice(&ip.octets()),
            Some(IpAddr::V6(ip)) => dst.put_slice(&ip.octets()),
        }
    }
}
#[cfg(test)]
mod tests {
    use std::net::{Ipv4Addr, Ipv6Addr};
    use bytes::Buf;

    // Verify the on-wire layout for IPv4 and IPv6 addresses: ae byte, reserved byte,
    // rx_cost, interval, then the address octets.
    #[test]
    fn encoding() {
        let mut buf = bytes::BytesMut::new();
        let ihu = super::Ihu {
            rx_cost: 25.into(),
            interval: 400,
            address: Some(Ipv4Addr::new(1, 1, 1, 1).into()),
        };
        ihu.write_bytes(&mut buf);
        assert_eq!(buf.len(), 10);
        assert_eq!(buf[..10], [1, 0, 0, 25, 1, 144, 1, 1, 1, 1]);
        let mut buf = bytes::BytesMut::new();
        let ihu = super::Ihu {
            rx_cost: 100.into(),
            interval: 4000,
            address: Some(Ipv6Addr::new(2, 0, 1234, 2345, 3456, 4567, 5678, 1).into()),
        };
        ihu.write_bytes(&mut buf);
        assert_eq!(buf.len(), 22);
        assert_eq!(
            buf[..22],
            [2, 0, 0, 100, 15, 160, 0, 2, 0, 0, 4, 210, 9, 41, 13, 128, 17, 215, 22, 46, 0, 1]
        );
    }

    // Decode all four address encodings (wildcard, IPv4, IPv6, IPv6 link-local) and
    // check the resulting fields plus full buffer consumption.
    #[test]
    fn decoding() {
        let mut buf = bytes::BytesMut::from(&[0, 0, 0, 1, 1, 44][..]);
        let ihu = super::Ihu {
            rx_cost: 1.into(),
            interval: 300,
            address: None,
        };
        let buf_len = buf.len();
        assert_eq!(super::Ihu::from_bytes(&mut buf, buf_len as u8), Some(ihu));
        assert_eq!(buf.remaining(), 0);
        let mut buf = bytes::BytesMut::from(&[1, 0, 0, 2, 0, 44, 3, 4, 5, 6][..]);
        let ihu = super::Ihu {
            rx_cost: 2.into(),
            interval: 44,
            address: Some(Ipv4Addr::new(3, 4, 5, 6).into()),
        };
        let buf_len = buf.len();
        assert_eq!(super::Ihu::from_bytes(&mut buf, buf_len as u8), Some(ihu));
        assert_eq!(buf.remaining(), 0);
        let mut buf = bytes::BytesMut::from(
            &[
                2, 0, 0, 2, 0, 44, 4, 0, 0, 0, 0, 5, 0, 6, 7, 8, 9, 10, 11, 12, 13, 14,
            ][..],
        );
        let ihu = super::Ihu {
            rx_cost: 2.into(),
            interval: 44,
            address: Some(Ipv6Addr::new(0x400, 0, 5, 6, 0x708, 0x90a, 0xb0c, 0xd0e).into()),
        };
        let buf_len = buf.len();
        assert_eq!(super::Ihu::from_bytes(&mut buf, buf_len as u8), Some(ihu));
        assert_eq!(buf.remaining(), 0);
        let mut buf = bytes::BytesMut::from(&[3, 0, 1, 2, 0, 42, 7, 8, 9, 10, 11, 12, 13, 14][..]);
        let ihu = super::Ihu {
            rx_cost: 258.into(),
            interval: 42,
            address: Some(Ipv6Addr::new(0xfe80, 0, 0, 0, 0x708, 0x90a, 0xb0c, 0xd0e).into()),
        };
        let buf_len = buf.len();
        assert_eq!(super::Ihu::from_bytes(&mut buf, buf_len as u8), Some(ihu));
        assert_eq!(buf.remaining(), 0);
    }

    // An unknown AE must cause the TLV to be dropped while its body is still consumed,
    // so the decoder stays aligned on the next TLV.
    #[test]
    fn decode_ignores_invalid_ae_encoding() {
        // AE 4 as it is the first one which should be used in protocol extension, causing this
        // test to fail if we forget to update something
        let mut buf = bytes::BytesMut::from(
            &[
                4, 0, 0, 2, 0, 44, 2, 0, 0, 0, 0, 5, 0, 6, 7, 8, 9, 10, 11, 12, 13, 14,
            ][..],
        );
        let buf_len = buf.len();
        assert_eq!(super::Ihu::from_bytes(&mut buf, buf_len as u8), None);
        // Decode function should still consume the required amount of bytes to leave parser in a
        // good state (assuming the length in the tlv preamble is good).
        assert_eq!(buf.remaining(), 0);
    }

    // Encoding followed by decoding yields the original value and consumes every byte.
    #[test]
    fn roundtrip() {
        let mut buf = bytes::BytesMut::new();
        let hello_src = super::Ihu::new(
            16.into(),
            400,
            Some(Ipv6Addr::new(156, 5646, 4164, 1236, 872, 960, 10, 844).into()),
        );
        hello_src.write_bytes(&mut buf);
        let buf_len = buf.len();
        let decoded = super::Ihu::from_bytes(&mut buf, buf_len as u8);
        assert_eq!(Some(hello_src), decoded);
        assert_eq!(buf.remaining(), 0);
    }
}

View File

@@ -0,0 +1,301 @@
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use bytes::{Buf, BufMut};
use tracing::trace;
use crate::subnet::Subnet;
use super::{AE_IPV4, AE_IPV6, AE_IPV6_LL, AE_WILDCARD};
/// Base wire size of a [`RouteRequest`] without variable length address encoding.
const ROUTE_REQUEST_BASE_WIRE_SIZE: u8 = 3;
/// Route request TLV body as defined in https://datatracker.ietf.org/doc/html/rfc8966#name-route-request
#[derive(Debug, Clone, PartialEq)]
pub struct RouteRequest {
    /// The prefix being requested; `None` requests a full route table dump (wildcard).
    prefix: Option<Subnet>,
    /// The requests' generation, i.e. the amount of times it has been forwarded.
    generation: u8,
}
impl RouteRequest {
    /// Creates a new `RouteRequest` for the given [`prefix`]. If no [`prefix`] is given, a full
    /// route table dump is requested.
    ///
    /// [`prefix`]: Subnet
    pub fn new(prefix: Option<Subnet>, generation: u8) -> Self {
        Self { prefix, generation }
    }

    /// Return the [`prefix`](Subnet) associated with this `RouteRequest`.
    pub fn prefix(&self) -> Option<Subnet> {
        self.prefix
    }

    /// Return the generation of the `RouteRequest`, which is the amount of times it has been
    /// forwarded already.
    pub fn generation(&self) -> u8 {
        self.generation
    }

    /// Increment the generation of the `RouteRequest`.
    pub fn inc_generation(&mut self) {
        self.generation += 1
    }

    /// Calculates the size on the wire of this `RouteRequest`.
    pub fn wire_size(&self) -> u8 {
        // The prefix is encoded in the minimal amount of bytes needed to hold
        // `prefix_len` bits; a wildcard request carries no prefix bytes at all.
        ROUTE_REQUEST_BASE_WIRE_SIZE
            + (if let Some(prefix) = self.prefix {
                prefix.prefix_len().div_ceil(8)
            } else {
                0
            })
    }

    /// Construct a `RouteRequest` from wire bytes.
    ///
    /// Returns `None` if the TLV is malformed (unknown address encoding or an impossible
    /// prefix length). In every rejection path the remaining bytes of the TLV are still
    /// consumed, so the caller's buffer stays aligned on the next TLV.
    ///
    /// # Panics
    ///
    /// This function will panic if there are insufficient bytes present in the provided buffer to
    /// decode a complete `RouteRequest`.
    pub fn from_bytes(src: &mut bytes::BytesMut, len: u8) -> Option<Self> {
        let generation = src.get_u8();
        let ae = src.get_u8();
        let plen = src.get_u8();
        // 3 bytes (generation, ae, plen) of the TLV body have been consumed at this
        // point; `remainder` is what must be skipped when rejecting the TLV so the
        // decoder stays in a good state.
        let remainder = (len as usize).saturating_sub(3);
        let prefix_size = plen.div_ceil(8) as usize;
        let prefix_ip = match ae {
            AE_WILDCARD => {
                if plen != 0 {
                    // Prefix len MUST be 0 for wildcard requests.
                    trace!("Invalid prefix length for wildcard in route_request packet, drop packet");
                    src.advance(remainder);
                    return None;
                }
                None
            }
            AE_IPV4 => {
                if plen > 32 {
                    // Impossible IPv4 prefix length; skip the rest of the body and drop.
                    trace!("Invalid prefix length in route_request packet, drop packet");
                    src.advance(remainder);
                    return None;
                }
                let mut raw_ip = [0; 4];
                raw_ip[..prefix_size].copy_from_slice(&src[..prefix_size]);
                src.advance(prefix_size);
                Some(Ipv4Addr::from(raw_ip).into())
            }
            AE_IPV6 => {
                if plen > 128 {
                    trace!("Invalid prefix length in route_request packet, drop packet");
                    src.advance(remainder);
                    return None;
                }
                let mut raw_ip = [0; 16];
                raw_ip[..prefix_size].copy_from_slice(&src[..prefix_size]);
                src.advance(prefix_size);
                Some(Ipv6Addr::from(raw_ip).into())
            }
            AE_IPV6_LL => {
                if plen != 64 {
                    trace!("Invalid prefix length for link local in route_request packet, drop packet");
                    src.advance(remainder);
                    return None;
                }
                // Only the 8-byte interface identifier is on the wire; fe80::/64 is implied.
                let mut raw_ip = [0; 16];
                raw_ip[0] = 0xfe;
                raw_ip[1] = 0x80;
                raw_ip[8..].copy_from_slice(&src[..8]);
                src.advance(8);
                Some(Ipv6Addr::from(raw_ip).into())
            }
            _ => {
                // Invalid AE type, skip remaining data and ignore
                trace!("Invalid AE type in route_request packet, drop packet");
                src.advance(remainder);
                return None;
            }
        };
        let prefix = prefix_ip.and_then(|prefix| Subnet::new(prefix, plen).ok());
        trace!("Read route_request tlv body");
        Some(RouteRequest { prefix, generation })
    }

    /// Encode this `RouteRequest` tlv as part of a packet.
    pub fn write_bytes(&self, dst: &mut bytes::BytesMut) {
        dst.put_u8(self.generation);
        if let Some(prefix) = self.prefix {
            dst.put_u8(match prefix.address() {
                IpAddr::V4(_) => AE_IPV4,
                IpAddr::V6(_) => AE_IPV6,
            });
            dst.put_u8(prefix.prefix_len());
            // Only write the minimal amount of bytes needed to hold `prefix_len` bits.
            let prefix_len = prefix.prefix_len().div_ceil(8) as usize;
            match prefix.address() {
                IpAddr::V4(ip) => dst.put_slice(&ip.octets()[..prefix_len]),
                IpAddr::V6(ip) => dst.put_slice(&ip.octets()[..prefix_len]),
            }
        } else {
            dst.put_u8(AE_WILDCARD);
            // Prefix len MUST be 0 for wildcard requests
            dst.put_u8(0);
        }
    }
}
#[cfg(test)]
mod tests {
    use std::net::{Ipv4Addr, Ipv6Addr};
    use bytes::Buf;
    use crate::subnet::Subnet;

    // Verify the on-wire layout: generation byte, ae byte, prefix length byte, then the
    // minimal number of prefix octets (none for a wildcard request).
    #[test]
    fn encoding() {
        let mut buf = bytes::BytesMut::new();
        let rr = super::RouteRequest {
            prefix: Some(
                Subnet::new(Ipv6Addr::new(512, 25, 26, 27, 28, 0, 0, 29).into(), 64)
                    .expect("64 is a valid IPv6 prefix size; qed"),
            ),
            generation: 2,
        };
        rr.write_bytes(&mut buf);
        assert_eq!(buf.len(), 11);
        assert_eq!(buf[..11], [2, 2, 64, 2, 0, 0, 25, 0, 26, 0, 27]);
        let mut buf = bytes::BytesMut::new();
        let rr = super::RouteRequest {
            prefix: Some(
                Subnet::new(Ipv4Addr::new(10, 101, 4, 1).into(), 32)
                    .expect("32 is a valid IPv4 prefix size; qed"),
            ),
            generation: 3,
        };
        rr.write_bytes(&mut buf);
        assert_eq!(buf.len(), 7);
        assert_eq!(buf[..7], [3, 1, 32, 10, 101, 4, 1]);
        let mut buf = bytes::BytesMut::new();
        let rr = super::RouteRequest {
            prefix: None,
            generation: 0,
        };
        rr.write_bytes(&mut buf);
        assert_eq!(buf.len(), 3);
        assert_eq!(buf[..3], [0, 0, 0]);
    }

    // Decode wildcard, IPv4, IPv6 and IPv6 link-local requests and verify the parsed
    // values plus full buffer consumption.
    #[test]
    fn decoding() {
        let mut buf = bytes::BytesMut::from(&[12, 0, 0][..]);
        let rr = super::RouteRequest {
            prefix: None,
            generation: 12,
        };
        let buf_len = buf.len();
        assert_eq!(
            super::RouteRequest::from_bytes(&mut buf, buf_len as u8),
            Some(rr)
        );
        assert_eq!(buf.remaining(), 0);
        let mut buf = bytes::BytesMut::from(&[24, 1, 24, 10, 15, 19][..]);
        let rr = super::RouteRequest {
            prefix: Some(
                Subnet::new(Ipv4Addr::new(10, 15, 19, 0).into(), 24)
                    .expect("24 is a valid IPv4 prefix size; qed"),
            ),
            generation: 24,
        };
        let buf_len = buf.len();
        assert_eq!(
            super::RouteRequest::from_bytes(&mut buf, buf_len as u8),
            Some(rr)
        );
        assert_eq!(buf.remaining(), 0);
        let mut buf = bytes::BytesMut::from(&[7, 2, 64, 0, 10, 0, 20, 0, 30, 0, 40][..]);
        let rr = super::RouteRequest {
            prefix: Some(
                Subnet::new(Ipv6Addr::new(10, 20, 30, 40, 0, 0, 0, 0).into(), 64)
                    .expect("64 is a valid IPv6 prefix size; qed"),
            ),
            generation: 7,
        };
        let buf_len = buf.len();
        assert_eq!(
            super::RouteRequest::from_bytes(&mut buf, buf_len as u8),
            Some(rr)
        );
        assert_eq!(buf.remaining(), 0);
        let mut buf = bytes::BytesMut::from(&[4, 3, 64, 0, 10, 0, 20, 0, 30, 0, 40][..]);
        let rr = super::RouteRequest {
            prefix: Some(
                Subnet::new(Ipv6Addr::new(0xfe80, 0, 0, 0, 10, 20, 30, 40).into(), 64)
                    .expect("64 is a valid IPv6 prefix size; qed"),
            ),
            generation: 4,
        };
        let buf_len = buf.len();
        assert_eq!(
            super::RouteRequest::from_bytes(&mut buf, buf_len as u8),
            Some(rr)
        );
        assert_eq!(buf.remaining(), 0);
    }

    // An unknown AE must cause the TLV to be dropped while its body is still consumed,
    // so the decoder stays aligned on the next TLV.
    #[test]
    fn decode_ignores_invalid_ae_encoding() {
        // AE 4 as it is the first one which should be used in protocol extension, causing this
        // test to fail if we forget to update something
        let mut buf = bytes::BytesMut::from(
            &[
                0, 4, 64, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
            ][..],
        );
        let buf_len = buf.len();
        assert_eq!(
            super::RouteRequest::from_bytes(&mut buf, buf_len as u8),
            None
        );
        // Decode function should still consume the required amount of bytes to leave parser in a
        // good state (assuming the length in the tlv preamble is good).
        assert_eq!(buf.remaining(), 0);
    }

    // Encoding followed by decoding yields the original value and consumes every byte.
    #[test]
    fn roundtrip() {
        let mut buf = bytes::BytesMut::new();
        let seqno_src = super::RouteRequest::new(
            Some(
                Subnet::new(
                    Ipv6Addr::new(0x21f, 0x4025, 0xabcd, 0xdead, 0, 0, 0, 0).into(),
                    64,
                )
                .expect("64 is a valid IPv6 prefix size; qed"),
            ),
            27,
        );
        seqno_src.write_bytes(&mut buf);
        let buf_len = buf.len();
        let decoded = super::RouteRequest::from_bytes(&mut buf, buf_len as u8);
        assert_eq!(Some(seqno_src), decoded);
        assert_eq!(buf.remaining(), 0);
    }
}

View File

@@ -0,0 +1,356 @@
use std::{
net::{IpAddr, Ipv4Addr, Ipv6Addr},
num::NonZeroU8,
};
use bytes::{Buf, BufMut};
use tracing::{debug, trace};
use crate::{router_id::RouterId, sequence_number::SeqNo, subnet::Subnet};
use super::{AE_IPV4, AE_IPV6, AE_IPV6_LL, AE_WILDCARD};
/// The default HOP COUNT value used in new SeqNo requests, as per https://datatracker.ietf.org/doc/html/rfc8966#section-3.8.2.1
// SAFETY: value is not zero.
const DEFAULT_HOP_COUNT: NonZeroU8 = NonZeroU8::new(64).unwrap();
/// Base wire size of a [`SeqNoRequest`] without variable length address encoding.
const SEQNO_REQUEST_BASE_WIRE_SIZE: u8 = 6 + RouterId::BYTE_SIZE as u8;
/// Seqno request TLV body as defined in https://datatracker.ietf.org/doc/html/rfc8966#name-seqno-request
#[derive(Debug, Clone, PartialEq)]
pub struct SeqNoRequest {
    /// The sequence number that is being requested.
    seqno: SeqNo,
    /// The maximum number of times this TLV may be forwarded, plus 1. Always non-zero;
    /// requests received with a hop count of 0 are rejected when decoding.
    hop_count: NonZeroU8,
    /// The router id that is being requested.
    router_id: RouterId,
    /// The prefix being requested
    prefix: Subnet,
}
impl SeqNoRequest {
/// Create a new `SeqNoRequest` for the given [prefix](Subnet) advertised by the [`RouterId`],
/// with the required new [`SeqNo`].
pub fn new(seqno: SeqNo, router_id: RouterId, prefix: Subnet) -> SeqNoRequest {
Self {
seqno,
hop_count: DEFAULT_HOP_COUNT,
router_id,
prefix,
}
}
/// Return the [`prefix`](Subnet) associated with this `SeqNoRequest`.
pub fn prefix(&self) -> Subnet {
self.prefix
}
/// Return the [`RouterId`] associated with this `SeqNoRequest`.
pub fn router_id(&self) -> RouterId {
self.router_id
}
/// Return the requested [`SeqNo`] associated with this `SeqNoRequest`.
pub fn seqno(&self) -> SeqNo {
self.seqno
}
/// Get the hop count for this `SeqNoRequest`.
pub fn hop_count(&self) -> u8 {
self.hop_count.into()
}
/// Decrement the hop count for this `SeqNoRequest`.
///
/// # Panics
///
/// This function will panic if the hop count before calling this function is 1, as that will
/// result in a hop count of 0, which is illegal for a `SeqNoRequest`. It is up to the caller
/// to ensure this condition holds.
pub fn decrement_hop_count(&mut self) {
// SAFETY: The panic from this expect is documented in the function signature.
self.hop_count = NonZeroU8::new(self.hop_count.get() - 1)
.expect("Decrementing a hop count of 1 is not allowed");
}
/// Calculates the size on the wire of this `Update`.
pub fn wire_size(&self) -> u8 {
SEQNO_REQUEST_BASE_WIRE_SIZE + self.prefix.prefix_len().div_ceil(8)
// TODO: Wildcard should be encoded differently
}
/// Construct a `SeqNoRequest` from wire bytes.
///
/// The expected layout is: AE (1 byte), prefix length (1 byte), seqno (2 bytes), hop count
/// (1 byte), one reserved byte, the router id, and finally the (compressed) prefix bytes.
///
/// # Panics
///
/// This function will panic if there are insufficient bytes present in the provided buffer to
/// decode a complete `SeqNoRequest`.
pub fn from_bytes(src: &mut bytes::BytesMut, len: u8) -> Option<Self> {
    let ae = src.get_u8();
    let plen = src.get_u8();
    let seqno = src.get_u16().into();
    let hop_count = src.get_u8();
    // Read "reserved" value, we assume this is 0
    let _ = src.get_u8();
    let mut router_id_bytes = [0u8; RouterId::BYTE_SIZE];
    router_id_bytes.copy_from_slice(&src[..RouterId::BYTE_SIZE]);
    src.advance(RouterId::BYTE_SIZE);
    let router_id = RouterId::from(router_id_bytes);
    // Prefixes are transmitted with trailing zero bytes stripped, so only ceil(plen / 8)
    // bytes are present on the wire.
    let prefix_size = plen.div_ceil(8) as usize;
    let prefix = match ae {
        AE_WILDCARD => {
            // A wildcard must not carry any prefix bits.
            if plen != 0 {
                return None;
            }
            // TODO: this is a temporary placeholder until we figure out how to handle this
            Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0).into()
        }
        AE_IPV4 => {
            if plen > 32 {
                return None;
            }
            let mut raw_ip = [0; 4];
            raw_ip[..prefix_size].copy_from_slice(&src[..prefix_size]);
            src.advance(prefix_size);
            Ipv4Addr::from(raw_ip).into()
        }
        AE_IPV6 => {
            if plen > 128 {
                return None;
            }
            let mut raw_ip = [0; 16];
            raw_ip[..prefix_size].copy_from_slice(&src[..prefix_size]);
            src.advance(prefix_size);
            Ipv6Addr::from(raw_ip).into()
        }
        AE_IPV6_LL => {
            // Link-local addresses always use the fixed fe80::/64 prefix; only the
            // interface identifier (8 bytes) is on the wire.
            if plen != 64 {
                return None;
            }
            let mut raw_ip = [0; 16];
            raw_ip[0] = 0xfe;
            raw_ip[1] = 0x80;
            raw_ip[8..].copy_from_slice(&src[..8]);
            src.advance(8);
            Ipv6Addr::from(raw_ip).into()
        }
        _ => {
            // Invalid AE type, skip remaining data and ignore
            // NOTE(review): 46 is the number of bytes consumed above (6 header bytes + the
            // 40 byte router id); `len as usize - 46` underflows (and panics) if len < 46 —
            // presumably guaranteed by the caller, see the # Panics section.
            trace!("Invalid AE type in seqno_request packet, drop packet");
            src.advance(len as usize - 46);
            return None;
        }
    };
    let prefix = Subnet::new(prefix, plen).ok()?;
    trace!("Read seqno_request tlv body");
    // Make sure hop_count is valid
    let hop_count = if let Some(hc) = NonZeroU8::new(hop_count) {
        hc
    } else {
        debug!("Dropping seqno_request as hop_count field is set to 0");
        return None;
    };
    Some(SeqNoRequest {
        seqno,
        hop_count,
        router_id,
        prefix,
    })
}
/// Encode this `SeqNoRequest` tlv as part of a packet.
pub fn write_bytes(&self, dst: &mut bytes::BytesMut) {
    let ae = match self.prefix.address() {
        IpAddr::V4(_) => AE_IPV4,
        IpAddr::V6(_) => AE_IPV6,
    };
    dst.put_u8(ae);
    dst.put_u8(self.prefix.prefix_len());
    dst.put_u16(self.seqno.into());
    dst.put_u8(self.hop_count.into());
    // The "reserved" byte is always written as 0.
    dst.put_u8(0);
    dst.put_slice(&self.router_id.as_bytes()[..]);
    // Only the significant bytes of the prefix go on the wire.
    let prefix_len = self.prefix.prefix_len().div_ceil(8) as usize;
    match self.prefix.address() {
        IpAddr::V4(ip) => dst.put_slice(&ip.octets()[..prefix_len]),
        IpAddr::V6(ip) => dst.put_slice(&ip.octets()[..prefix_len]),
    }
}
}
#[cfg(test)]
mod tests {
    use std::{
        net::{Ipv4Addr, Ipv6Addr},
        num::NonZeroU8,
    };

    use crate::{router_id::RouterId, subnet::Subnet};
    use bytes::Buf;

    #[test]
    fn encoding() {
        let mut buf = bytes::BytesMut::new();

        let snr = super::SeqNoRequest {
            seqno: 17.into(),
            hop_count: NonZeroU8::new(64).unwrap(),
            prefix: Subnet::new(Ipv6Addr::new(512, 25, 26, 27, 28, 0, 0, 29).into(), 64)
                .expect("64 is a valid IPv6 prefix size; qed"),
            router_id: RouterId::from([1u8; RouterId::BYTE_SIZE]),
        };

        snr.write_bytes(&mut buf);

        assert_eq!(buf.len(), 54);
        assert_eq!(
            buf[..54],
            [
                2, 64, 0, 17, 64, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 0, 0, 25, 0, 26, 0, 27,
            ]
        );

        let mut buf = bytes::BytesMut::new();

        let snr = super::SeqNoRequest {
            seqno: 170.into(),
            hop_count: NonZeroU8::new(111).unwrap(),
            prefix: Subnet::new(Ipv4Addr::new(10, 101, 4, 1).into(), 32)
                .expect("32 is a valid IPv4 prefix size; qed"),
            router_id: RouterId::from([2u8; RouterId::BYTE_SIZE]),
        };

        snr.write_bytes(&mut buf);

        assert_eq!(buf.len(), 50);
        assert_eq!(
            buf[..50],
            [
                1, 32, 0, 170, 111, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
                2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 10, 101, 4, 1,
            ]
        );
    }

    #[test]
    fn decoding() {
        let mut buf = bytes::BytesMut::from(
            &[
                0, 0, 0, 0, 1, 0, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
                3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
            ][..],
        );

        let snr = super::SeqNoRequest {
            hop_count: NonZeroU8::new(1).unwrap(),
            seqno: 0.into(),
            prefix: Subnet::new(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0).into(), 0)
                .expect("0 is a valid IPv6 prefix size; qed"),
            router_id: RouterId::from([3u8; RouterId::BYTE_SIZE]),
        };

        let buf_len = buf.len();
        assert_eq!(
            super::SeqNoRequest::from_bytes(&mut buf, buf_len as u8),
            Some(snr)
        );
        assert_eq!(buf.remaining(), 0);

        let mut buf = bytes::BytesMut::from(
            &[
                3, 64, 0, 42, 232, 0, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
                4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 10, 0, 20, 0, 30, 0,
                40,
            ][..],
        );

        let snr = super::SeqNoRequest {
            seqno: 42.into(),
            hop_count: NonZeroU8::new(232).unwrap(),
            prefix: Subnet::new(Ipv6Addr::new(0xfe80, 0, 0, 0, 10, 20, 30, 40).into(), 64)
                .expect("64 is a valid IPv6 prefix size; qed"),
            router_id: RouterId::from([4u8; RouterId::BYTE_SIZE]),
        };

        let buf_len = buf.len();
        assert_eq!(
            super::SeqNoRequest::from_bytes(&mut buf, buf_len as u8),
            Some(snr)
        );
        assert_eq!(buf.remaining(), 0);
    }

    #[test]
    fn decode_ignores_invalid_ae_encoding() {
        // AE 4 as it is the first one which should be used in protocol extension, causing this
        // test to fail if we forget to update something
        let mut buf = bytes::BytesMut::from(
            &[
                4, 64, 0, 0, 44, 0, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
                5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 7, 8, 9, 10, 11, 12,
                13, 14, 15, 16, 17, 18, 19, 20, 21,
            ][..],
        );

        let buf_len = buf.len();
        assert_eq!(
            super::SeqNoRequest::from_bytes(&mut buf, buf_len as u8),
            None
        );
        // Decode function should still consume the required amount of bytes to leave parser in a
        // good state (assuming the length in the tlv preamble is good).
        assert_eq!(buf.remaining(), 0);
    }

    #[test]
    fn decode_ignores_invalid_hop_count() {
        // A hop count of 0 is invalid on the wire and must cause the TLV to be dropped, while
        // still consuming the full body.
        let mut buf = bytes::BytesMut::from(
            &[
                3, 64, 92, 0, 0, 0, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
                4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 10, 0, 20, 0, 30, 0,
                40,
            ][..],
        );

        let buf_len = buf.len();
        assert_eq!(
            super::SeqNoRequest::from_bytes(&mut buf, buf_len as u8),
            None
        );
        assert_eq!(buf.remaining(), 0);
    }

    #[test]
    fn roundtrip() {
        let mut buf = bytes::BytesMut::new();

        let seqno_src = super::SeqNoRequest::new(
            64.into(),
            RouterId::from([6; RouterId::BYTE_SIZE]),
            Subnet::new(
                Ipv6Addr::new(0x21f, 0x4025, 0xabcd, 0xdead, 0, 0, 0, 0).into(),
                64,
            )
            .expect("64 is a valid IPv6 prefix size; qed"),
        );

        seqno_src.write_bytes(&mut buf);
        let buf_len = buf.len();
        let decoded = super::SeqNoRequest::from_bytes(&mut buf, buf_len as u8);

        assert_eq!(Some(seqno_src), decoded);
        assert_eq!(buf.remaining(), 0);
    }
}

72
mycelium/src/babel/tlv.rs Normal file
View File

@@ -0,0 +1,72 @@
pub use super::{hello::Hello, ihu::Ihu, update::Update};
use super::{route_request::RouteRequest, SeqNoRequest};
/// A single `Tlv` in a babel packet body.
#[derive(Debug, Clone, PartialEq)]
pub enum Tlv {
    /// Hello Tlv type.
    Hello(Hello),
    /// Ihu Tlv type.
    Ihu(Ihu),
    /// Update Tlv type.
    Update(Update),
    /// RouteRequest Tlv type.
    RouteRequest(RouteRequest),
    /// SeqNoRequest Tlv type.
    SeqNoRequest(SeqNoRequest),
}
impl Tlv {
    /// Calculate the size on the wire for this `Tlv`. This DOES NOT include the TLV header size
    /// (2 bytes).
    pub fn wire_size(&self) -> u8 {
        // Delegate to the wrapped TLV body.
        match self {
            Self::Hello(hello) => hello.wire_size(),
            Self::Ihu(ihu) => ihu.wire_size(),
            Self::Update(update) => update.wire_size(),
            Self::RouteRequest(route_request) => route_request.wire_size(),
            Self::SeqNoRequest(seqno_request) => seqno_request.wire_size(),
        }
    }

    /// Encode this `Tlv` as part of a packet.
    pub fn write_bytes(&self, dst: &mut bytes::BytesMut) {
        // Delegate to the wrapped TLV body.
        match self {
            Self::Hello(hello) => hello.write_bytes(dst),
            Self::Ihu(ihu) => ihu.write_bytes(dst),
            Self::Update(update) => update.write_bytes(dst),
            Self::RouteRequest(route_request) => route_request.write_bytes(dst),
            Self::SeqNoRequest(seqno_request) => seqno_request.write_bytes(dst),
        }
    }
}
// Conversions allowing each TLV body type to be wrapped into a `Tlv` with `.into()`.

impl From<SeqNoRequest> for Tlv {
    fn from(v: SeqNoRequest) -> Self {
        Self::SeqNoRequest(v)
    }
}

impl From<RouteRequest> for Tlv {
    fn from(v: RouteRequest) -> Self {
        Self::RouteRequest(v)
    }
}

impl From<Update> for Tlv {
    fn from(v: Update) -> Self {
        Self::Update(v)
    }
}

impl From<Ihu> for Tlv {
    fn from(v: Ihu) -> Self {
        Self::Ihu(v)
    }
}

impl From<Hello> for Tlv {
    fn from(v: Hello) -> Self {
        Self::Hello(v)
    }
}

View File

@@ -0,0 +1,385 @@
//! The babel [Update TLV](https://datatracker.ietf.org/doc/html/rfc8966#name-update).
use std::{
net::{IpAddr, Ipv4Addr, Ipv6Addr},
time::Duration,
};
use bytes::{Buf, BufMut};
use tracing::trace;
use crate::{metric::Metric, router_id::RouterId, sequence_number::SeqNo, subnet::Subnet};
use super::{AE_IPV4, AE_IPV6, AE_IPV6_LL, AE_WILDCARD};
/// Flag bit indicating an [`Update`] TLV establishes a new default prefix.
#[allow(dead_code)]
const UPDATE_FLAG_PREFIX: u8 = 0x80;
/// Flag bit indicating an [`Update`] TLV establishes a new default router-id.
#[allow(dead_code)]
const UPDATE_FLAG_ROUTER_ID: u8 = 0x40;
/// Mask to apply to [`Update`] flags, leaving only valid flags.
const FLAG_MASK: u8 = 0b1100_0000;
/// Base wire size of an [`Update`] without variable length address encoding.
const UPDATE_BASE_WIRE_SIZE: u8 = 10 + RouterId::BYTE_SIZE as u8;
/// Update TLV body as defined in https://datatracker.ietf.org/doc/html/rfc8966#name-update.
#[derive(Debug, Clone, PartialEq)]
pub struct Update {
    /// Flags set in the TLV. Only bits covered by [`FLAG_MASK`] are valid.
    flags: u8,
    /// Upper bound in centiseconds after which a new `Update` is sent. Must not be 0.
    interval: u16,
    /// Sender's sequence number.
    seqno: SeqNo,
    /// Sender's metric for this route.
    metric: Metric,
    /// The [`Subnet`] contained in this update. An update packet itself can contain any allowed
    /// subnet.
    subnet: Subnet,
    /// Router id of the sender. Importantly this is not part of the update itself, though we do
    /// transmit it for now as such.
    router_id: RouterId,
}
impl Update {
    /// Create a new `Update`.
    ///
    /// The `interval` is converted to centiseconds, the unit used on the wire. A value which
    /// does not fit in the `u16` wire representation is clamped to `u16::MAX` instead of
    /// silently wrapping.
    pub fn new(
        interval: Duration,
        seqno: SeqNo,
        metric: Metric,
        subnet: Subnet,
        router_id: RouterId,
    ) -> Self {
        // Saturate rather than cast: a plain `as u16` would silently wrap for intervals
        // longer than roughly 655 seconds.
        let interval_centiseconds = u16::try_from(interval.as_millis() / 10).unwrap_or(u16::MAX);
        Self {
            // No flags used for now
            flags: 0,
            interval: interval_centiseconds,
            seqno,
            metric,
            subnet,
            router_id,
        }
    }

    /// Returns the [`SeqNo`] of the sender of this `Update`.
    pub fn seqno(&self) -> SeqNo {
        self.seqno
    }

    /// Return the [`Metric`] of the sender for this route in the `Update`.
    pub fn metric(&self) -> Metric {
        self.metric
    }

    /// Return the [`Subnet`] in this `Update`.
    pub fn subnet(&self) -> Subnet {
        self.subnet
    }

    /// Return the [`router-id`](PublicKey) of the router who advertised this [`Prefix`](IpAddr).
    pub fn router_id(&self) -> RouterId {
        self.router_id
    }

    /// Calculates the size on the wire of this `Update`.
    pub fn wire_size(&self) -> u8 {
        // Only the significant bytes of the prefix are transmitted.
        let address_bytes = self.subnet.prefix_len().div_ceil(8);
        UPDATE_BASE_WIRE_SIZE + address_bytes
    }

    /// Get the time until a new `Update` for the [`Subnet`] is received at the latest.
    pub fn interval(&self) -> Duration {
        // Interval is expressed as centiseconds on the wire.
        Duration::from_millis(self.interval as u64 * 10)
    }

    /// Construct an `Update` from wire bytes.
    ///
    /// # Panics
    ///
    /// This function will panic if there are insufficient bytes present in the provided buffer to
    /// decode a complete `Update`.
    pub fn from_bytes(src: &mut bytes::BytesMut, len: u8) -> Option<Self> {
        let ae = src.get_u8();
        // Mask out undefined flag bits so they can't end up in our state.
        let flags = src.get_u8() & FLAG_MASK;
        let plen = src.get_u8();
        // Read "omitted" value, we assume this is 0
        let _ = src.get_u8();
        let interval = src.get_u16();
        let seqno = src.get_u16().into();
        let metric = src.get_u16().into();
        // Prefixes are transmitted with trailing zero bytes stripped.
        let prefix_size = plen.div_ceil(8) as usize;
        let prefix = match ae {
            AE_WILDCARD => {
                if prefix_size != 0 {
                    return None;
                }
                // TODO: this is a temporary placeholder until we figure out how to handle this
                Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0).into()
            }
            AE_IPV4 => {
                if plen > 32 {
                    return None;
                }
                let mut raw_ip = [0; 4];
                raw_ip[..prefix_size].copy_from_slice(&src[..prefix_size]);
                src.advance(prefix_size);
                Ipv4Addr::from(raw_ip).into()
            }
            AE_IPV6 => {
                if plen > 128 {
                    return None;
                }
                let mut raw_ip = [0; 16];
                raw_ip[..prefix_size].copy_from_slice(&src[..prefix_size]);
                src.advance(prefix_size);
                Ipv6Addr::from(raw_ip).into()
            }
            AE_IPV6_LL => {
                // Link-local addresses always use the fixed fe80::/64 prefix; only the
                // interface identifier (8 bytes) is on the wire.
                if plen != 64 {
                    return None;
                }
                let mut raw_ip = [0; 16];
                raw_ip[0] = 0xfe;
                raw_ip[1] = 0x80;
                raw_ip[8..].copy_from_slice(&src[..8]);
                src.advance(8);
                Ipv6Addr::from(raw_ip).into()
            }
            _ => {
                // Invalid AE type, skip remaining data and ignore. 10 is the number of bytes
                // consumed before the prefix bytes.
                trace!("Invalid AE type in update packet, drop packet");
                src.advance(len as usize - 10);
                return None;
            }
        };
        let subnet = Subnet::new(prefix, plen).ok()?;
        let mut router_id_bytes = [0u8; RouterId::BYTE_SIZE];
        router_id_bytes.copy_from_slice(&src[..RouterId::BYTE_SIZE]);
        src.advance(RouterId::BYTE_SIZE);
        let router_id = RouterId::from(router_id_bytes);
        trace!("Read update tlv body");
        Some(Update {
            flags,
            interval,
            seqno,
            metric,
            subnet,
            router_id,
        })
    }

    /// Encode this `Update` tlv as part of a packet.
    pub fn write_bytes(&self, dst: &mut bytes::BytesMut) {
        dst.put_u8(match self.subnet.address() {
            IpAddr::V4(_) => AE_IPV4,
            IpAddr::V6(_) => AE_IPV6,
        });
        dst.put_u8(self.flags);
        dst.put_u8(self.subnet.prefix_len());
        // Write "omitted" value, currently not used in our encoding scheme.
        dst.put_u8(0);
        dst.put_u16(self.interval);
        dst.put_u16(self.seqno.into());
        dst.put_u16(self.metric.into());
        // Only the significant bytes of the prefix go on the wire.
        let prefix_len = self.subnet.prefix_len().div_ceil(8) as usize;
        match self.subnet.address() {
            IpAddr::V4(ip) => dst.put_slice(&ip.octets()[..prefix_len]),
            IpAddr::V6(ip) => dst.put_slice(&ip.octets()[..prefix_len]),
        }
        dst.put_slice(&self.router_id.as_bytes()[..])
    }
}
#[cfg(test)]
mod tests {
    use std::{
        net::{Ipv4Addr, Ipv6Addr},
        time::Duration,
    };

    use crate::{router_id::RouterId, subnet::Subnet};
    use bytes::Buf;

    #[test]
    fn encoding() {
        let mut buf = bytes::BytesMut::new();

        let update = super::Update {
            flags: 0b1100_0000,
            interval: 400,
            seqno: 17.into(),
            metric: 25.into(),
            subnet: Subnet::new(Ipv6Addr::new(512, 25, 26, 27, 28, 0, 0, 29).into(), 64)
                .expect("64 is a valid IPv6 prefix size; qed"),
            router_id: RouterId::from([1u8; RouterId::BYTE_SIZE]),
        };

        update.write_bytes(&mut buf);

        assert_eq!(buf.len(), 58);
        assert_eq!(
            buf[..58],
            [
                2, 192, 64, 0, 1, 144, 0, 17, 0, 25, 2, 0, 0, 25, 0, 26, 0, 27, 1, 1, 1, 1, 1, 1,
                1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                1, 1, 1, 1, 1, 1
            ]
        );

        let mut buf = bytes::BytesMut::new();

        let update = super::Update {
            flags: 0b0000_0000,
            interval: 600,
            seqno: 170.into(),
            metric: 256.into(),
            subnet: Subnet::new(Ipv4Addr::new(10, 101, 4, 1).into(), 23)
                .expect("23 is a valid IPv4 prefix size; qed"),
            router_id: RouterId::from([2u8; RouterId::BYTE_SIZE]),
        };

        update.write_bytes(&mut buf);

        assert_eq!(buf.len(), 53);
        assert_eq!(
            buf[..53],
            [
                1, 0, 23, 0, 2, 88, 0, 170, 1, 0, 10, 101, 4, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
                2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2
            ]
        );
    }

    #[test]
    fn decoding() {
        let mut buf = bytes::BytesMut::from(
            &[
                0, 64, 0, 0, 0, 100, 0, 70, 2, 0, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
                3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
            ][..],
        );

        let update = super::Update {
            flags: 0b0100_0000,
            interval: 100,
            seqno: 70.into(),
            metric: 512.into(),
            subnet: Subnet::new(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0).into(), 0)
                .expect("0 is a valid IPv6 prefix size; qed"),
            router_id: RouterId::from([3u8; RouterId::BYTE_SIZE]),
        };

        let buf_len = buf.len();
        assert_eq!(
            super::Update::from_bytes(&mut buf, buf_len as u8),
            Some(update)
        );
        assert_eq!(buf.remaining(), 0);

        let mut buf = bytes::BytesMut::from(
            &[
                3, 0, 64, 0, 3, 232, 0, 42, 3, 1, 0, 10, 0, 20, 0, 30, 0, 40, 4, 4, 4, 4, 4, 4, 4,
                4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
                4, 4, 4, 4, 4,
            ][..],
        );

        let update = super::Update {
            flags: 0b0000_0000,
            interval: 1000,
            seqno: 42.into(),
            metric: 769.into(),
            subnet: Subnet::new(Ipv6Addr::new(0xfe80, 0, 0, 0, 10, 20, 30, 40).into(), 64)
                .expect("64 is a valid IPv6 prefix size; qed"),
            router_id: RouterId::from([4u8; RouterId::BYTE_SIZE]),
        };

        let buf_len = buf.len();
        assert_eq!(
            super::Update::from_bytes(&mut buf, buf_len as u8),
            Some(update)
        );
        assert_eq!(buf.remaining(), 0);
    }

    #[test]
    fn decode_ignores_invalid_ae_encoding() {
        // AE 4 as it is the first one which should be used in protocol extension, causing this
        // test to fail if we forget to update something
        let mut buf = bytes::BytesMut::from(
            &[
                4, 0, 64, 0, 0, 44, 2, 0, 0, 10, 10, 5, 0, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
                17, 18, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
                5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
            ][..],
        );

        let buf_len = buf.len();
        assert_eq!(super::Update::from_bytes(&mut buf, buf_len as u8), None);
        // Decode function should still consume the required amount of bytes to leave parser in a
        // good state (assuming the length in the tlv preamble is good).
        assert_eq!(buf.remaining(), 0);
    }

    #[test]
    fn decode_ignores_invalid_flag_bits() {
        // Set all flag bits, only allowed bits should be set on the decoded value
        let mut buf = bytes::BytesMut::from(
            &[
                3, 255, 64, 0, 3, 232, 0, 42, 3, 1, 0, 10, 0, 20, 0, 30, 0, 40, 4, 4, 4, 4, 4, 4,
                4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
                4, 4, 4, 4, 4, 4,
            ][..],
        );

        let update = super::Update {
            flags: super::UPDATE_FLAG_PREFIX | super::UPDATE_FLAG_ROUTER_ID,
            interval: 1000,
            seqno: 42.into(),
            metric: 769.into(),
            subnet: Subnet::new(Ipv6Addr::new(0xfe80, 0, 0, 0, 10, 20, 30, 40).into(), 64)
                .expect("64 is a valid IPv6 prefix size; qed"),
            router_id: RouterId::from([4u8; RouterId::BYTE_SIZE]),
        };

        let buf_len = buf.len();
        assert_eq!(
            super::Update::from_bytes(&mut buf, buf_len as u8),
            Some(update)
        );
        assert_eq!(buf.remaining(), 0);
    }

    #[test]
    fn roundtrip() {
        let mut buf = bytes::BytesMut::new();

        let update_src = super::Update::new(
            Duration::from_secs(64),
            10.into(),
            25.into(),
            Subnet::new(
                Ipv6Addr::new(0x21f, 0x4025, 0xabcd, 0xdead, 0, 0, 0, 0).into(),
                64,
            )
            .expect("64 is a valid IPv6 prefix size; qed"),
            RouterId::from([6; RouterId::BYTE_SIZE]),
        );

        update_src.write_bytes(&mut buf);
        let buf_len = buf.len();
        let decoded = super::Update::from_bytes(&mut buf, buf_len as u8);

        assert_eq!(Some(update_src), decoded);
        assert_eq!(buf.remaining(), 0);
    }
}

338
mycelium/src/cdn.rs Normal file
View File

@@ -0,0 +1,338 @@
use std::path::PathBuf;
use aes_gcm::{aead::Aead, KeyInit};
use axum::{
extract::{Query, State},
http::{HeaderMap, StatusCode},
routing::get,
Router,
};
use axum_extra::extract::Host;
use futures::{stream::FuturesUnordered, StreamExt};
use reqwest::header::CONTENT_TYPE;
use tokio::net::TcpListener;
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info, warn};
/// Cdn functionality. Urls of a specific format lead to downloading of metadata from the
/// registry, and serving of chunks.
pub struct Cdn {
    // Directory in which reconstructed blocks are cached on disk.
    cache: PathBuf,
    // Cancelling this token shuts down the spawned HTTP server; triggered on `Drop`.
    cancel_token: CancellationToken,
}
/// Cache for reconstructed blocks.
#[derive(Clone)]
struct Cache {
    // Root directory of the cache; blocks are stored in files named after the hex encoded
    // encrypted hash of the block.
    base: PathBuf,
}
impl Cdn {
pub fn new(cache: PathBuf) -> Self {
let cancel_token = CancellationToken::new();
Self {
cache,
cancel_token,
}
}
/// Start the Cdn server. This future runs until the server is stopped.
pub fn start(&self, listener: TcpListener) -> Result<(), Box<dyn std::error::Error>> {
let state = Cache {
base: self.cache.clone(),
};
if !self.cache.exists() {
info!(dir = %self.cache.display(), "Creating cache dir");
std::fs::create_dir(&self.cache)?;
}
if !self.cache.is_dir() {
return Err("Cache dir is not a directory".into());
}
let router = Router::new().route("/", get(cdn)).with_state(state);
let cancel_token = self.cancel_token.clone();
tokio::spawn(async {
axum::serve(listener, router)
.with_graceful_shutdown(cancel_token.cancelled_owned())
.await
.map_err(|err| {
warn!(%err, "Cdn server error");
})
});
Ok(())
}
}
// Query parameters accepted by the `cdn` handler: an optional hex encoded decryption key.
#[derive(Debug, serde::Deserialize)]
struct DecryptionKeyQuery {
    key: Option<String>,
}
/// HTTP handler serving CDN content.
///
/// The first label of the host name is expected to be the hex encoded (16 byte) hash of a
/// metadata blob; the remaining labels form the registry host to load that blob from. An
/// optional `key` query parameter (hex encoded, 16 bytes) is used to decrypt the metadata.
/// Files are served by recombining their blocks; directories render as an HTML link index.
#[tracing::instrument(level = tracing::Level::DEBUG, skip(cache))]
async fn cdn(
    Host(host): Host,
    Query(query): Query<DecryptionKeyQuery>,
    State(cache): State<Cache>,
) -> Result<(HeaderMap, Vec<u8>), StatusCode> {
    debug!("Received request at {host}");
    let mut parts = host.split('.');
    let prefix = parts
        .next()
        .expect("Splitting a String always yields at least 1 result; Qed.");
    // A 16 byte hash is 32 hex characters.
    if prefix.len() != 32 {
        return Err(StatusCode::BAD_REQUEST);
    }
    let mut hash = [0; 16];
    faster_hex::hex_decode(prefix.as_bytes(), &mut hash).map_err(|_| StatusCode::BAD_REQUEST)?;
    // Remaining host labels identify the registry to load metadata from.
    let registry_url = parts.collect::<Vec<_>>().join(".");
    let decryption_key = if let Some(query_key) = query.key {
        let mut key = [0; 16];
        faster_hex::hex_decode(query_key.as_bytes(), &mut key)
            .map_err(|_| StatusCode::BAD_REQUEST)?;
        Some(key)
    } else {
        None
    };
    let meta = load_meta(registry_url.clone(), hash, decryption_key).await?;
    debug!("Metadata loaded");
    let mut headers = HeaderMap::new();
    match meta {
        cdn_meta::Metadata::File(file) => {
            // Set the content type if the metadata carries a mime type.
            if let Some(mime) = file.mime {
                debug!(%mime, "Setting mime type");
                headers.append(
                    CONTENT_TYPE,
                    mime.parse().map_err(|_| {
                        warn!("Not serving file with unprocessable mime type");
                        StatusCode::UNPROCESSABLE_ENTITY
                    })?,
                );
            }
            // File recombination: concatenate every (decrypted) block in order.
            let mut content = vec![];
            for block in file.blocks {
                content.extend_from_slice(cache.fetch_block(&block).await?.as_slice());
            }
            Ok((headers, content))
        }
        cdn_meta::Metadata::Directory(dir) => {
            // Render a minimal HTML index with one link per directory entry.
            let mut out = r#"
<!DOCTYPE html>
<html i18n-values="dir:textdirection;lang:language">
<head>
<meta charset="utf-8">
</head>
<body>
<ul>"#
                .to_string();
            headers.append(
                CONTENT_TYPE,
                "text/html"
                    .parse()
                    .expect("Can parse \"text/html\" to content-type"),
            );
            // NOTE(review): entry metadata is fetched sequentially; could be concurrent.
            for (file_hash, encryption_key) in dir.files {
                let meta = load_meta(registry_url.clone(), file_hash, encryption_key).await?;
                let name = match meta {
                    cdn_meta::Metadata::File(file) => file.name,
                    cdn_meta::Metadata::Directory(dir) => dir.name,
                };
                out.push_str(&format!(
                    "<li><a href=\"http://{}.{registry_url}/?key={}\">{name}</a></li>\n",
                    faster_hex::hex_string(&file_hash),
                    &encryption_key
                        .map(|ek| faster_hex::hex_string(&ek))
                        .unwrap_or_else(String::new),
                ));
            }
            out.push_str("</ul></body></html>");
            Ok((headers, out.into()))
        }
    }
}
/// Load a metadata blob from a metadata repository.
///
/// The blob identified by `hash` is fetched over HTTP from `registry_url`. If an
/// `encryption_key` is provided the payload is treated as AES-128-GCM encrypted, with the
/// 12 byte nonce appended after the ciphertext.
async fn load_meta(
    registry_url: String,
    hash: cdn_meta::Hash,
    encryption_key: Option<cdn_meta::Hash>,
) -> Result<cdn_meta::Metadata, StatusCode> {
    let mut r_url = reqwest::Url::parse(&format!("http://{registry_url}")).map_err(|err| {
        error!(%err, "Could not parse registry URL");
        StatusCode::INTERNAL_SERVER_ERROR
    })?;
    let hex_hash = faster_hex::hex_string(&hash);
    r_url.set_path(&format!("/api/v1/metadata/{hex_hash}"));
    r_url.set_scheme("http").map_err(|_| {
        error!("Could not set HTTP scheme");
        StatusCode::INTERNAL_SERVER_ERROR
    })?;

    debug!(url = %r_url, "Fetching chunk");
    let metadata_reply = reqwest::get(r_url).await.map_err(|err| {
        error!(%err, "Could not load metadata from registry");
        StatusCode::INTERNAL_SERVER_ERROR
    })?;
    // TODO: Should we just check if status code is success here?
    if metadata_reply.status() != StatusCode::OK {
        debug!(
            status = %metadata_reply.status(),
            "Registry replied with non-OK status code"
        );
        return Err(metadata_reply.status());
    }
    let encrypted_metadata = metadata_reply.bytes().await.map_err(|err| {
        error!(%err, "Could not load metadata response from registry");
        StatusCode::INTERNAL_SERVER_ERROR
    })?;
    let metadata = if let Some(encryption_key) = encryption_key {
        // A valid payload is at least as long as the 12 byte nonce.
        if encrypted_metadata.len() < 12 {
            debug!("Attempting to decrypt metadata with insufficient size");
            return Err(StatusCode::UNPROCESSABLE_ENTITY);
        }
        let decryptor = aes_gcm::Aes128Gcm::new(&encryption_key.into());
        // The nonce is stored in the last 12 bytes, after the ciphertext.
        let plaintext = decryptor
            .decrypt(
                encrypted_metadata[encrypted_metadata.len() - 12..].into(),
                &encrypted_metadata[..encrypted_metadata.len() - 12],
            )
            .map_err(|_| {
                warn!("Decryption of block failed");
                // Either the decryption key is wrong or the blob is corrupt, we assume the
                // registry is not a fault so the decryption key is wrong, which is a user error.
                StatusCode::UNPROCESSABLE_ENTITY
            })?;
        plaintext
    } else {
        encrypted_metadata.into()
    };

    // If the metadata is not decodable, this is not really our fault, but also not necessarily
    // the users fault.
    let (meta, consumed) =
        cdn_meta::Metadata::from_binary(&metadata).map_err(|_| StatusCode::UNPROCESSABLE_ENTITY)?;
    if consumed != metadata.len() {
        warn!(
            metadata_length = metadata.len(),
            consumed, "Trailing binary metadata which wasn't decoded"
        );
    }

    Ok(meta)
}
impl Drop for Cdn {
    fn drop(&mut self) {
        // Gracefully shut down the server spawned in `start` when the handle is dropped.
        self.cancel_token.cancel();
    }
}
/// Download a shard from a 0-db.
///
/// Connects to the 0-db at `location.host`, selects `location.namespace`, and fetches the
/// value stored under `key`.
async fn download_shard(
    location: &cdn_meta::Location,
    key: &[u8],
) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
    let client = redis::Client::open(format!("redis://{}", location.host))?;
    let mut con = client.get_multiplexed_async_connection().await?;
    // 0-db namespaces are selected like redis databases.
    redis::cmd("SELECT")
        .arg(&location.namespace)
        .query_async::<()>(&mut con)
        .await?;

    Ok(redis::cmd("GET").arg(key).query_async(&mut con).await?)
}
impl Cache {
async fn fetch_block(&self, block: &cdn_meta::Block) -> Result<Vec<u8>, StatusCode> {
let mut cached_file_path = self.base.clone();
cached_file_path.push(faster_hex::hex_string(&block.encrypted_hash));
// If we have the file in cache, just open it, load it, and return from there.
if cached_file_path.exists() {
return tokio::fs::read(&cached_file_path).await.map_err(|err| {
error!(%err, "Could not load cached file");
StatusCode::INTERNAL_SERVER_ERROR
});
}
// File is not in cache, download and save
// TODO: Rank based on expected latency
// FIXME: Only download the required amount
let mut shard_stream = block
.shards
.iter()
.enumerate()
.map(|(i, loc)| async move { (i, download_shard(loc, &block.encrypted_hash).await) })
.collect::<FuturesUnordered<_>>();
let mut shards = vec![None; block.shards.len()];
while let Some((idx, shard)) = shard_stream.next().await {
let shard = shard.map_err(|err| {
warn!(err, "Could not load shard");
StatusCode::INTERNAL_SERVER_ERROR
})?;
shards[idx] = Some(shard);
}
// recombine
let encoder = reed_solomon_erasure::galois_8::ReedSolomon::new(
block.required_shards as usize,
block.shards.len() - block.required_shards as usize,
)
.map_err(|err| {
error!(%err, "Failed to construct erausre codec");
StatusCode::INTERNAL_SERVER_ERROR
})?;
encoder.reconstruct_data(&mut shards).map_err(|err| {
error!(%err, "Shard recombination failed");
StatusCode::INTERNAL_SERVER_ERROR
})?;
// SAFETY: Since decoding was succesfull, the first shards (data shards) must be
// Option::Some
let mut encrypted_data = shards
.into_iter()
.map(Option::unwrap)
.take(block.required_shards as usize)
.flatten()
.collect::<Vec<_>>();
let padding_len = encrypted_data[encrypted_data.len() - 1] as usize;
encrypted_data.resize(encrypted_data.len() - padding_len, 0);
let decryptor = aes_gcm::Aes128Gcm::new(&block.content_hash.into());
let c = decryptor
.decrypt(&block.nonce.into(), encrypted_data.as_slice())
.map_err(|err| {
warn!(%err, "Decryption of content block failed");
StatusCode::UNPROCESSABLE_ENTITY
})?;
// Save file to cache, this is not critical if it fails
if let Err(err) = tokio::fs::write(&cached_file_path, &c).await {
warn!(%err, "Could not write block to cache");
};
Ok(c)
}
}

158
mycelium/src/connection.rs Normal file
View File

@@ -0,0 +1,158 @@
use std::{io, net::SocketAddr, pin::Pin};
use tokio::{
io::{AsyncRead, AsyncWrite},
net::TcpStream,
};
mod tracked;
pub use tracked::Tracked;
#[cfg(feature = "private-network")]
mod tls;
/// Cost to add to the peer_link_cost for "local processing", when peers are connected over IPv6.
///
/// The current peer link cost is calculated from a HELLO rtt. This is great to measure link
/// latency, since packets are processed in order. However, on local idle links, this value will
/// likely be 0 since we round down (from the amount of ms it took to process), which does not
/// accurately reflect the fact that there is in fact a cost associated with using a peer, even on
/// these local links.
const PACKET_PROCESSING_COST_IP6_TCP: u16 = 10;
/// Cost to add to the peer_link_cost for "local processing", when peers are connected over IPv4.
///
/// This is similar to [`PACKET_PROCESSING_COST_IP6_TCP`], but slightly higher so we skew towards
/// IPv6 connections if peers are connected over both IPv4 and IPv6.
const PACKET_PROCESSING_COST_IP4_TCP: u16 = 15;
// TODO
const PACKET_PROCESSING_COST_IP6_QUIC: u16 = 7;
// TODO
const PACKET_PROCESSING_COST_IP4_QUIC: u16 = 12;
/// A bidirectional byte stream a peer is reachable over, extended with metadata about the
/// link itself.
pub trait Connection: AsyncRead + AsyncWrite {
    /// Get an identifier for this connection, which shows details about the remote
    fn identifier(&self) -> Result<String, io::Error>;

    /// The static cost of using this connection
    fn static_link_cost(&self) -> Result<u16, io::Error>;
}
/// A wrapper around a quic send and quic receive stream, implementing the [`Connection`] trait.
pub struct Quic {
    // Outgoing stream half.
    tx: quinn::SendStream,
    // Incoming stream half.
    rx: quinn::RecvStream,
    // Remote address, kept for identification and cost calculation.
    remote: SocketAddr,
}

impl Quic {
    /// Create a new wrapper around Quic streams.
    pub fn new(tx: quinn::SendStream, rx: quinn::RecvStream, remote: SocketAddr) -> Self {
        Quic { tx, rx, remote }
    }
}
impl Connection for TcpStream {
    fn identifier(&self) -> Result<String, io::Error> {
        let local = self.local_addr()?;
        let remote = self.peer_addr()?;
        Ok(format!("TCP {local} <-> {remote}"))
    }

    fn static_link_cost(&self) -> Result<u16, io::Error> {
        // IPv4-mapped IPv6 peers are billed as plain IPv4 connections.
        let is_ipv4 = match self.peer_addr()? {
            SocketAddr::V4(_) => true,
            SocketAddr::V6(addr) => addr.ip().to_ipv4_mapped().is_some(),
        };
        Ok(if is_ipv4 {
            PACKET_PROCESSING_COST_IP4_TCP
        } else {
            PACKET_PROCESSING_COST_IP6_TCP
        })
    }
}
impl AsyncRead for Quic {
    // Reads delegate directly to the QUIC receive stream.
    #[inline]
    fn poll_read(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<io::Result<()>> {
        Pin::new(&mut self.rx).poll_read(cx, buf)
    }
}
impl AsyncWrite for Quic {
    // Writes delegate directly to the QUIC send stream; quinn's write error is converted to a
    // std::io::Error.
    #[inline]
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<Result<usize, io::Error>> {
        Pin::new(&mut self.tx)
            .poll_write(cx, buf)
            .map_err(From::from)
    }

    #[inline]
    fn poll_flush(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), io::Error>> {
        Pin::new(&mut self.tx).poll_flush(cx)
    }

    #[inline]
    fn poll_shutdown(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), io::Error>> {
        Pin::new(&mut self.tx).poll_shutdown(cx)
    }

    #[inline]
    fn poll_write_vectored(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        bufs: &[io::IoSlice<'_>],
    ) -> std::task::Poll<Result<usize, io::Error>> {
        Pin::new(&mut self.tx).poll_write_vectored(cx, bufs)
    }

    #[inline]
    fn is_write_vectored(&self) -> bool {
        self.tx.is_write_vectored()
    }
}
impl Connection for Quic {
    fn identifier(&self) -> Result<String, io::Error> {
        let remote = self.remote;
        Ok(format!("QUIC -> {remote}"))
    }

    fn static_link_cost(&self) -> Result<u16, io::Error> {
        // IPv4-mapped IPv6 peers are billed as plain IPv4 connections.
        let is_ipv4 = match self.remote {
            SocketAddr::V4(_) => true,
            SocketAddr::V6(addr) => addr.ip().to_ipv4_mapped().is_some(),
        };
        Ok(if is_ipv4 {
            PACKET_PROCESSING_COST_IP4_QUIC
        } else {
            PACKET_PROCESSING_COST_IP6_QUIC
        })
    }
}
#[cfg(test)]
use tokio::io::DuplexStream;
// Test-only implementation so in-memory pipes can stand in for real connections.
#[cfg(test)]
impl Connection for DuplexStream {
    fn identifier(&self) -> Result<String, io::Error> {
        Ok("Memory pipe".to_string())
    }

    // Fixed nominal cost for in-memory test connections.
    fn static_link_cost(&self) -> Result<u16, io::Error> {
        Ok(1)
    }
}

View File

@@ -0,0 +1,23 @@
use std::{io, net::SocketAddr};
use tokio::net::TcpStream;
impl super::Connection for tokio_openssl::SslStream<TcpStream> {
    fn identifier(&self) -> Result<String, io::Error> {
        let stream = self.get_ref();
        Ok(format!(
            "TLS {} <-> {}",
            stream.local_addr()?,
            stream.peer_addr()?
        ))
    }

    fn static_link_cost(&self) -> Result<u16, io::Error> {
        // IPv4-mapped IPv6 peers are billed as plain IPv4 connections.
        let is_ipv4 = match self.get_ref().peer_addr()? {
            SocketAddr::V4(_) => true,
            SocketAddr::V6(addr) => addr.ip().to_ipv4_mapped().is_some(),
        };
        Ok(if is_ipv4 {
            super::PACKET_PROCESSING_COST_IP4_TCP
        } else {
            super::PACKET_PROCESSING_COST_IP6_TCP
        })
    }
}

View File

@@ -0,0 +1,120 @@
use std::{
pin::Pin,
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
task::Poll,
};
use tokio::io::{AsyncRead, AsyncWrite};
use super::Connection;
/// Wrapper which keeps track of how many bytes have been read and written from a connection.
pub struct Tracked<C> {
    /// Bytes read counter
    read: Arc<AtomicU64>,
    /// Bytes written counter
    write: Arc<AtomicU64>,
    /// Underlying connection we are measuring
    con: C,
}
impl<C> Tracked<C>
where
    C: Connection + Unpin,
{
    /// Create a new instance of a tracked connection. Counters are passed in so they can be
    /// reused across connections.
    pub fn new(read: Arc<AtomicU64>, write: Arc<AtomicU64>, con: C) -> Self {
        Self { read, write, con }
    }
}
impl<C> Connection for Tracked<C>
where
    C: Connection + Unpin,
{
    // Both metadata methods simply delegate to the wrapped connection.
    #[inline]
    fn identifier(&self) -> Result<String, std::io::Error> {
        self.con.identifier()
    }

    #[inline]
    fn static_link_cost(&self) -> Result<u16, std::io::Error> {
        self.con.static_link_cost()
    }
}
impl<C> AsyncRead for Tracked<C>
where
    C: AsyncRead + Unpin,
{
    #[inline]
    fn poll_read(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        // The growth of the buffer's filled region is the number of bytes this call produced;
        // only count it on a successful ready read.
        let start_len = buf.filled().len();
        let res = Pin::new(&mut self.con).poll_read(cx, buf);
        if let Poll::Ready(Ok(())) = res {
            self.read
                .fetch_add((buf.filled().len() - start_len) as u64, Ordering::Relaxed);
        }
        res
    }
}
impl<C> AsyncWrite for Tracked<C>
where
    C: AsyncWrite + Unpin,
{
    /// Forward the write, adding the amount of bytes accepted by the inner
    /// connection to the shared write counter.
    #[inline]
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, std::io::Error>> {
        let poll_result = Pin::new(&mut self.con).poll_write(cx, buf);
        if let Poll::Ready(Ok(accepted)) = poll_result {
            self.write.fetch_add(accepted as u64, Ordering::Relaxed);
        }
        poll_result
    }

    /// Flushing transfers no new payload bytes, so no accounting is needed.
    #[inline]
    fn poll_flush(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        Pin::new(&mut self.con).poll_flush(cx)
    }

    /// Shutdown is likewise a pure passthrough.
    #[inline]
    fn poll_shutdown(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        Pin::new(&mut self.con).poll_shutdown(cx)
    }

    /// Vectored writes are tracked the same way as regular writes.
    #[inline]
    fn poll_write_vectored(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> Poll<Result<usize, std::io::Error>> {
        let poll_result = Pin::new(&mut self.con).poll_write_vectored(cx, bufs);
        if let Poll::Ready(Ok(accepted)) = poll_result {
            self.write.fetch_add(accepted as u64, Ordering::Relaxed);
        }
        poll_result
    }

    /// Whether vectored writes are actually efficient depends solely on the
    /// wrapped connection.
    #[inline]
    fn is_write_vectored(&self) -> bool {
        self.con.is_write_vectored()
    }
}

450
mycelium/src/crypto.rs Normal file
View File

@@ -0,0 +1,450 @@
//! Abstraction over diffie hellman, symmetric encryption, and hashing.
use core::fmt;
use std::{
error::Error,
fmt::Display,
net::Ipv6Addr,
ops::{Deref, DerefMut},
};
use aes_gcm::{aead::OsRng, AeadCore, AeadInPlace, Aes256Gcm, Key, KeyInit};
use serde::{de::Visitor, Deserialize, Serialize};
/// Default MTU for a packet. Ideally this would not be needed and the [`PacketBuffer`] takes a
/// const generic argument which is then expanded with the needed extra space for the buffer,
/// however as it stands const generics can only be used standalone and not in a constant
/// expression. This _is_ possible on nightly rust, with a feature gate (generic_const_exprs).
const PACKET_SIZE: usize = 1400;
/// Size of an AES_GCM tag in bytes.
const AES_TAG_SIZE: usize = 16;
/// Size of an AES_GCM nonce in bytes.
const AES_NONCE_SIZE: usize = 12;
/// Size of user defined data header. This header will be part of the encrypted data.
const DATA_HEADER_SIZE: usize = 4;
/// Size of a `PacketBuffer`: payload plus scratch space for the internal header (prefix) and
/// the AES-GCM tag and nonce (suffix), so encryption can happen in place without reallocating.
const PACKET_BUFFER_SIZE: usize = PACKET_SIZE + AES_TAG_SIZE + AES_NONCE_SIZE + DATA_HEADER_SIZE;
/// A public key used as part of Diffie Hellman key exchange. It is derived from a [`SecretKey`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PublicKey(x25519_dalek::PublicKey);
/// A secret used as part of Diffie Hellman key exchange.
///
/// This type intentionally does not implement or derive [`Debug`] to avoid accidentally leaking
/// secrets in logs.
#[derive(Clone)]
pub struct SecretKey(x25519_dalek::StaticSecret);
/// A statically computed secret from a [`SecretKey`] and a [`PublicKey`].
///
/// This type intentionally does not implement or derive [`Debug`] to avoid accidentally leaking
/// secrets in logs.
#[derive(Clone)]
pub struct SharedSecret([u8; 32]);
/// A buffer for packets. This holds enough space to encrypt a packet in place without
/// reallocating.
///
/// Internally, the buffer is created with an additional header. Because this header is part of the
/// encrypted content, it is not included in the global version set by the main packet header. As
/// such, an internal version is included.
pub struct PacketBuffer {
    // Backing storage, always PACKET_BUFFER_SIZE bytes long until encryption truncates it.
    buf: Vec<u8>,
    /// Amount of bytes written in the buffer
    // Measured from the start of `buf`, i.e. it includes DATA_HEADER_SIZE once set_size is called.
    size: usize,
}
/// A reference to the header in a [`PacketBuffer`].
pub struct PacketBufferHeader<'a> {
    data: &'a [u8; DATA_HEADER_SIZE],
}
/// A mutable reference to the header in a [`PacketBuffer`].
pub struct PacketBufferHeaderMut<'a> {
    data: &'a mut [u8; DATA_HEADER_SIZE],
}
/// Opaque type indicating decryption failed.
#[derive(Debug, Clone, Copy)]
pub struct DecryptionError;
impl Display for DecryptionError {
    /// The message is deliberately unspecific: the caller learns nothing about why
    /// decryption failed beyond the fact that it did.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "Decryption failed, invalid or insufficient encrypted content for this key"
        )
    }
}

impl Error for DecryptionError {}
impl SecretKey {
/// Generate a new `StaticSecret` using [`OsRng`] as an entropy source.
pub fn new() -> Self {
SecretKey(x25519_dalek::StaticSecret::random_from_rng(OsRng))
}
/// View this `SecretKey` as a byte array.
#[inline]
pub fn as_bytes(&self) -> &[u8; 32] {
self.0.as_bytes()
}
/// Computes the [`SharedSecret`] from this `SecretKey` and a [`PublicKey`].
pub fn shared_secret(&self, other: &PublicKey) -> SharedSecret {
SharedSecret(self.0.diffie_hellman(&other.0).to_bytes())
}
}
impl Default for SecretKey {
fn default() -> Self {
Self::new()
}
}
impl PublicKey {
    /// Generates an [`Ipv6Addr`] from a `PublicKey`.
    ///
    /// The generated address is guaranteed to be part of the `400::/7` range.
    pub fn address(&self) -> Ipv6Addr {
        // Derive 16 bytes of address material from the BLAKE3 hash of the raw key.
        let mut hasher = blake3::Hasher::new();
        hasher.update(self.as_bytes());
        let mut octets = [0; 16];
        hasher.finalize_xof().fill(&mut octets);
        // Mangle the first byte to be of the expected form. Because of the network range
        // requirement, we MUST set the third bit, and MAY set the last bit. Instead of discarding
        // the first 7 bits of the hash, use the first byte to determine if the last bit is set.
        // If there is an odd number of bits set in the first byte, set the last bit of the result.
        let parity_bit = octets[0].count_ones() as u8 & 1;
        octets[0] = 0x04 | parity_bit;
        Ipv6Addr::from(octets)
    }

    /// Convert this `PublicKey` to a byte array.
    pub fn to_bytes(self) -> [u8; 32] {
        self.0.to_bytes()
    }

    /// View this `PublicKey` as a byte array.
    pub fn as_bytes(&self) -> &[u8; 32] {
        self.0.as_bytes()
    }
}
impl SharedSecret {
    /// Encrypt a [`PacketBuffer`] using the `SharedSecret` as key.
    ///
    /// Internally, a new random nonce will be generated using the OS's crypto rng generator. This
    /// nonce is appended to the encrypted data.
    pub fn encrypt(&self, mut data: PacketBuffer) -> Vec<u8> {
        let key: Key<Aes256Gcm> = self.0.into();
        let nonce = Aes256Gcm::generate_nonce(OsRng);
        let cipher = Aes256Gcm::new(&key);
        // Encrypt in place; the buffer was allocated with AES_TAG_SIZE + AES_NONCE_SIZE of
        // trailing scratch space precisely so tag and nonce fit behind the payload.
        let tag = cipher
            .encrypt_in_place_detached(&nonce, &[], &mut data.buf[..data.size])
            .expect("Encryption can't fail; qed.");
        // Wire layout: ciphertext || tag || nonce.
        data.buf[data.size..data.size + AES_TAG_SIZE].clone_from_slice(tag.as_slice());
        data.buf[data.size + AES_TAG_SIZE..data.size + AES_TAG_SIZE + AES_NONCE_SIZE]
            .clone_from_slice(&nonce);
        // Drop the unused remainder of the scratch space.
        data.buf.truncate(data.size + AES_NONCE_SIZE + AES_TAG_SIZE);
        data.buf
    }
    /// Decrypt a message previously encrypted with an equivalent `SharedSecret`. In other words, a
    /// message that was previously created by the [`SharedSecret::encrypt`] method.
    ///
    /// Internally, this messages assumes that a 12 byte nonce is present at the end of the data.
    /// If the passed in data to decrypt does not contain a valid nonce, decryption fails and an
    /// opaque error is returned. As an extension to this, if the data is not of sufficient length
    /// to contain a valid nonce, an error is returned immediately.
    pub fn decrypt(&self, mut data: Vec<u8>) -> Result<PacketBuffer, DecryptionError> {
        // Make sure we have sufficient data (i.e. a nonce).
        if data.len() < AES_NONCE_SIZE + AES_TAG_SIZE + DATA_HEADER_SIZE {
            return Err(DecryptionError);
        }
        let data_len = data.len();
        let key: Key<Aes256Gcm> = self.0.into();
        {
            // Split off nonce and tag from the tail (mirror of the encrypt layout), then
            // decrypt the ciphertext in place.
            let (data, nonce) = data.split_at_mut(data_len - AES_NONCE_SIZE);
            let (data, tag) = data.split_at_mut(data.len() - AES_TAG_SIZE);
            let cipher = Aes256Gcm::new(&key);
            cipher
                .decrypt_in_place_detached((&*nonce).into(), &[], data, (&*tag).into())
                .map_err(|_| DecryptionError)?;
        }
        Ok(PacketBuffer {
            // We did not remove the scratch space used for TAG and NONCE.
            size: data.len() - AES_TAG_SIZE - AES_NONCE_SIZE,
            buf: data,
        })
    }
}
impl PacketBuffer {
/// Create a new blank `PacketBuffer`.
pub fn new() -> Self {
Self {
buf: vec![0; PACKET_BUFFER_SIZE],
size: 0,
}
}
/// Get a reference to the packet header.
pub fn header(&self) -> PacketBufferHeader<'_> {
PacketBufferHeader {
data: self.buf[..DATA_HEADER_SIZE]
.try_into()
.expect("Header size constant is correct; qed"),
}
}
/// Get a mutable reference to the packet header.
pub fn header_mut(&mut self) -> PacketBufferHeaderMut<'_> {
PacketBufferHeaderMut {
data: <&mut [u8] as TryInto<&mut [u8; DATA_HEADER_SIZE]>>::try_into(
&mut self.buf[..DATA_HEADER_SIZE],
)
.expect("Header size constant is correct; qed"),
}
}
/// Get a reference to the entire useable inner buffer.
pub fn buffer(&self) -> &[u8] {
let buf_end = self.buf.len() - AES_NONCE_SIZE - AES_TAG_SIZE;
&self.buf[DATA_HEADER_SIZE..buf_end]
}
/// Get a mutable reference to the entire useable internal buffer.
pub fn buffer_mut(&mut self) -> &mut [u8] {
let buf_end = self.buf.len() - AES_NONCE_SIZE - AES_TAG_SIZE;
&mut self.buf[DATA_HEADER_SIZE..buf_end]
}
/// Sets the amount of bytes in use by the buffer.
pub fn set_size(&mut self, size: usize) {
self.size = size + DATA_HEADER_SIZE;
}
}
impl Default for PacketBuffer {
    fn default() -> Self {
        Self::new()
    }
}
impl From<[u8; 32]> for SecretKey {
    /// Load a secret key from a byte array.
    fn from(bytes: [u8; 32]) -> SecretKey {
        SecretKey(x25519_dalek::StaticSecret::from(bytes))
    }
}
impl fmt::Display for PublicKey {
    // Displays as lowercase hex, matching the serde representation below.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&faster_hex::hex_string(self.as_bytes()))
    }
}
impl Serialize for PublicKey {
    // Serialized as a 64-character hex string rather than raw bytes.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        serializer.serialize_str(&faster_hex::hex_string(self.as_bytes()))
    }
}
// Visitor driving hex-string deserialization of `PublicKey`.
struct PublicKeyVisitor;
impl Visitor<'_> for PublicKeyVisitor {
    type Value = PublicKey;
    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
        formatter.write_str("A hex encoded public key (64 characters)")
    }
    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
    where
        E: serde::de::Error,
    {
        if v.len() != 64 {
            Err(E::custom("Public key is 64 characters long"))
        } else {
            let mut backing = [0; 32];
            faster_hex::hex_decode(v.as_bytes(), &mut backing)
                .map_err(|_| E::custom("PublicKey is not valid hex"))?;
            Ok(PublicKey(backing.into()))
        }
    }
}
impl<'de> Deserialize<'de> for PublicKey {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        deserializer.deserialize_str(PublicKeyVisitor)
    }
}
impl From<[u8; 32]> for PublicKey {
    /// Given a byte array, construct a `PublicKey`.
    fn from(bytes: [u8; 32]) -> PublicKey {
        PublicKey(x25519_dalek::PublicKey::from(bytes))
    }
}
impl TryFrom<&str> for PublicKey {
    type Error = faster_hex::Error;
    // NOTE(review): no explicit length check here (unlike the serde visitor above);
    // presumably hex_decode rejects inputs whose length doesn't match the 32-byte
    // output buffer — confirm against faster_hex docs before relying on it.
    fn try_from(value: &str) -> Result<Self, Self::Error> {
        let mut output = [0u8; 32];
        faster_hex::hex_decode(value.as_bytes(), &mut output)?;
        Ok(PublicKey::from(output))
    }
}
impl From<&SecretKey> for PublicKey {
    // Derive the public key matching a secret key.
    fn from(value: &SecretKey) -> Self {
        PublicKey(x25519_dalek::PublicKey::from(&value.0))
    }
}
impl Deref for SharedSecret {
    type Target = [u8; 32];
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
impl Deref for PacketBuffer {
    type Target = [u8];
    // Dereferences to the written portion of the buffer, internal header included
    // (contrast with `buffer()`, which excludes it).
    fn deref(&self) -> &Self::Target {
        &self.buf[DATA_HEADER_SIZE..self.size]
    }
}
impl Deref for PacketBufferHeader<'_> {
    type Target = [u8; DATA_HEADER_SIZE];
    fn deref(&self) -> &Self::Target {
        self.data
    }
}
impl Deref for PacketBufferHeaderMut<'_> {
    type Target = [u8; DATA_HEADER_SIZE];
    fn deref(&self) -> &Self::Target {
        self.data
    }
}
impl DerefMut for PacketBufferHeaderMut<'_> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.data
    }
}
impl fmt::Debug for PacketBuffer {
    // Deliberately elides the (potentially sensitive / large) payload bytes.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("PacketBuffer")
            .field("data", &"...")
            .field("len", &self.size)
            .finish()
    }
}
#[cfg(test)]
mod tests {
    use super::{PacketBuffer, SecretKey, AES_NONCE_SIZE, AES_TAG_SIZE, DATA_HEADER_SIZE};
    #[test]
    /// Test if encryption works in general. We just create some random value and encrypt it.
    /// Specifically, this will help to catch runtime panics in case AES_TAG_SIZE or AES_NONCE_SIZE
    /// don't have a proper value aligned with the underlying AES_GCM implementation.
    fn encryption_succeeds() {
        let k1 = SecretKey::new();
        let k2 = SecretKey::new();
        let ss = k1.shared_secret(&(&k2).into());
        let mut pb = PacketBuffer::new();
        let data = b"vnno30nv f654q364 vfsv 44"; // Random keyboard smash.
        pb.buffer_mut()[..data.len()].copy_from_slice(data);
        pb.set_size(data.len());
        // We only care that this does not panic.
        let res = ss.encrypt(pb);
        // At the same time, check expected size.
        assert_eq!(
            res.len(),
            data.len() + DATA_HEADER_SIZE + AES_TAG_SIZE + AES_NONCE_SIZE
        );
    }
    #[test]
    /// Encrypt a value and then decrypt it. This makes sure the decrypt flow and encrypt flow
    /// match, and both follow the expected format. Also, we don't reuse the shared secret for
    /// decryption, but instead generate the secret again the other way round, to simulate a remote
    /// node.
    fn encrypt_decrypt_roundtrip() {
        let k1 = SecretKey::new();
        let k2 = SecretKey::new();
        // Diffie-Hellman: both derivations must yield the same shared secret.
        let ss1 = k1.shared_secret(&(&k2).into());
        let ss2 = k2.shared_secret(&(&k1).into());
        // This assertion is not strictly necessary as it will be checked below implicitly.
        assert_eq!(ss1.as_slice(), ss2.as_slice());
        let data = b"dsafjiqjo23 u2953u8 3oid fjo321j";
        let mut pb = PacketBuffer::new();
        pb.buffer_mut()[..data.len()].copy_from_slice(data);
        pb.set_size(data.len());
        let res = ss1.encrypt(pb);
        let original = ss2.decrypt(res).expect("Decryption works");
        assert_eq!(&*original, &data[..]);
    }
    #[test]
    /// Test if PacketBufferHeaderMut actually modifies the PacketBuffer storage.
    fn modify_header() {
        let mut pb = PacketBuffer::new();
        let mut header = pb.header_mut();
        header[0] = 1;
        header[1] = 2;
        header[2] = 3;
        header[3] = 4;
        assert_eq!(pb.buf[..DATA_HEADER_SIZE], [1, 2, 3, 4]);
    }
    #[test]
    /// Verify [`PacketBuffer::buffer`] and [`PacketBuffer::buffer_mut`] actually have the
    /// appropriate size.
    fn buffer_mapping() {
        let mut pb = PacketBuffer::new();
        assert_eq!(pb.buffer().len(), super::PACKET_SIZE);
        assert_eq!(pb.buffer_mut().len(), super::PACKET_SIZE);
    }
}

487
mycelium/src/data.rs Normal file
View File

@@ -0,0 +1,487 @@
use std::net::{IpAddr, Ipv6Addr};
use etherparse::{
icmpv6::{DestUnreachableCode, TimeExceededCode},
Icmpv6Type, PacketBuilder,
};
use futures::{Sink, SinkExt, Stream, StreamExt};
use tokio::sync::mpsc::UnboundedReceiver;
use tracing::{debug, error, trace, warn};
use crate::{crypto::PacketBuffer, metrics::Metrics, packet::DataPacket, router::Router};
/// Current version of the user data header.
const USER_DATA_VERSION: u8 = 1;
/// Type value indicating L3 data in the user data header.
const USER_DATA_L3_TYPE: u8 = 0;
/// Type value indicating a user message in the data header.
const USER_DATA_MESSAGE_TYPE: u8 = 1;
/// Type value indicating an ICMP packet not returned as regular IPv6 traffic. This is needed when
/// intermediate nodes send back icmp data, as the original data is encrypted.
const USER_DATA_OOB_ICMP: u8 = 2;
/// Minimum size in bytes of an IPv6 header.
const IPV6_MIN_HEADER_SIZE: usize = 40;
/// Size of an ICMPv6 header.
const ICMP6_HEADER_SIZE: usize = 8;
/// Minimum MTU for IPV6 according to https://www.rfc-editor.org/rfc/rfc8200#section-5.
/// For ICMP, the packet must not be greater than this value. This is specified in
/// https://datatracker.ietf.org/doc/html/rfc4443#section-2.4, section (c).
const MIN_IPV6_MTU: usize = 1280;
/// Mask applied to the first byte of an IP header to extract the version.
const IP_VERSION_MASK: u8 = 0b1111_0000;
/// Version byte of an IP header indicating IPv6. Since the version is only 4 bits, the lower bits
/// must be masked first.
const IPV6_VERSION_BYTE: u8 = 0b0110_0000;
/// Default hop limit for message packets. For now this is set to 64 hops.
///
/// For regular l3 packets, we copy the hop limit from the packet itself. We can't do that here, so
/// 64 is used as sane default.
const MESSAGE_HOP_LIMIT: u8 = 64;
/// The DataPlane manages forwarding/receiving of local data packets to the [`Router`], and the
/// encryption/decryption of them.
///
/// DataPlane itself can be cloned, but this is not cheap on the router and should be avoided.
pub struct DataPlane<M> {
    /// Router used for packet routing decisions and key (shared secret) lookups.
    router: Router<M>,
}
impl<M> DataPlane<M>
where
    M: Metrics + Clone + Send + 'static,
{
    /// Create a new `DataPlane` using the given [`Router`] for packet handling.
    ///
    /// `l3_packet_stream` is a stream of l3 packets from the host, usually read from a TUN interface.
    /// `l3_packet_sink` is a sink for l3 packets received from a remote, usually sent to a TUN interface.
    pub fn new<S, T, U>(
        router: Router<M>,
        l3_packet_stream: S,
        l3_packet_sink: T,
        message_packet_sink: U,
        host_packet_source: UnboundedReceiver<DataPacket>,
    ) -> Self
    where
        S: Stream<Item = Result<PacketBuffer, std::io::Error>> + Send + Unpin + 'static,
        T: Sink<PacketBuffer> + Clone + Send + Unpin + 'static,
        T::Error: std::fmt::Display,
        U: Sink<(PacketBuffer, IpAddr, IpAddr)> + Send + Unpin + 'static,
        U::Error: std::fmt::Display,
    {
        let dp = Self { router };
        // Host -> network direction. The sink is also passed in so locally generated ICMP
        // replies can be sent straight back to the host.
        tokio::spawn(
            dp.clone()
                .inject_l3_packet_loop(l3_packet_stream, l3_packet_sink.clone()),
        );
        // Network -> host direction.
        tokio::spawn(dp.clone().extract_packet_loop(
            l3_packet_sink,
            message_packet_sink,
            host_packet_source,
        ));
        dp
    }
    /// Get a reference to the [`Router`] used.
    pub fn router(&self) -> &Router<M> {
        &self.router
    }
    /// Read l3 packets from the host and inject them into the mesh, replying with locally
    /// generated ICMPv6 where a packet cannot be forwarded.
    async fn inject_l3_packet_loop<S, T>(self, mut l3_packet_stream: S, mut l3_packet_sink: T)
    where
        // TODO: no result
        // TODO: should IP extraction be handled higher up?
        S: Stream<Item = Result<PacketBuffer, std::io::Error>> + Send + Unpin + 'static,
        T: Sink<PacketBuffer> + Clone + Send + Unpin + 'static,
        T::Error: std::fmt::Display,
    {
        let node_subnet = self.router.node_tun_subnet();
        while let Some(packet) = l3_packet_stream.next().await {
            let mut packet = match packet {
                Err(e) => {
                    error!("Failed to read packet from TUN interface {e}");
                    continue;
                }
                Ok(packet) => packet,
            };
            trace!("Received packet from tun");
            // Parse an IPv6 header. We don't care about the full header in reality. What we want
            // to know is:
            // - This is an IPv6 header
            // - Hop limit
            // - Source address
            // - Destination address
            // This translates to the following requirements:
            // - at least 40 bytes of data, as that is the minimum size of an IPv6 header
            // - first 4 bits (version) are the constant 6 (0b0110)
            // - src is byte 9-24 (8-23 0 indexed).
            // - dst is byte 25-40 (24-39 0 indexed).
            if packet.len() < IPV6_MIN_HEADER_SIZE {
                trace!("Packet can't contain an IPv6 header");
                continue;
            }
            if packet[0] & IP_VERSION_MASK != IPV6_VERSION_BYTE {
                trace!("Packet is not IPv6");
                continue;
            }
            let hop_limit = u8::from_be_bytes([packet[7]]);
            let src_ip = Ipv6Addr::from(
                <&[u8] as TryInto<[u8; 16]>>::try_into(&packet[8..24])
                    .expect("Static range bounds on slice are correct length"),
            );
            let dst_ip = Ipv6Addr::from(
                <&[u8] as TryInto<[u8; 16]>>::try_into(&packet[24..40])
                    .expect("Static range bounds on slice are correct length"),
            );
            // If this is a packet for our own Subnet, it means there is no local configuration for
            // the destination ip or /64 subnet, and the IP is unreachable
            if node_subnet.contains_ip(dst_ip.into()) {
                trace!(
                    "Replying to local packet for unexisting address: {}",
                    dst_ip
                );
                let mut icmp_packet = PacketBuffer::new();
                let host = self.router.node_public_key().address().octets();
                let icmp = PacketBuilder::ipv6(host, src_ip.octets(), 64).icmpv6(
                    Icmpv6Type::DestinationUnreachable(DestUnreachableCode::Address),
                );
                // NOTE(review): `1280 - 48` appears to be MIN_IPV6_MTU - IPV6_MIN_HEADER_SIZE -
                // ICMP6_HEADER_SIZE (= 1232), spelled out with named constants in
                // `encrypt_and_route_packet` below — confirm and consider unifying.
                icmp_packet.set_size(icmp.size(packet.len().min(1280 - 48)));
                let mut writer = &mut icmp_packet.buffer_mut()[..];
                if let Err(e) = icmp.write(&mut writer, &packet[..packet.len().min(1280 - 48)]) {
                    error!("Failed to construct ICMP packet: {e}");
                    continue;
                }
                if let Err(e) = l3_packet_sink.send(icmp_packet).await {
                    error!("Failed to send ICMP packet to host: {e}");
                }
                continue;
            }
            trace!("Received packet from TUN with dest addr: {:?}", dst_ip);
            // Check if the source address is part of 400::/7
            let first_src_byte = src_ip.segments()[0] >> 8;
            if !(0x04..0x06).contains(&first_src_byte) {
                // Source is outside the overlay range: reject with an ICMP policy failure.
                let mut icmp_packet = PacketBuffer::new();
                let host = self.router.node_public_key().address().octets();
                let icmp = PacketBuilder::ipv6(host, src_ip.octets(), 64).icmpv6(
                    Icmpv6Type::DestinationUnreachable(
                        DestUnreachableCode::SourceAddressFailedPolicy,
                    ),
                );
                icmp_packet.set_size(icmp.size(packet.len().min(1280 - 48)));
                let mut writer = &mut icmp_packet.buffer_mut()[..];
                if let Err(e) = icmp.write(&mut writer, &packet[..packet.len().min(1280 - 48)]) {
                    error!("Failed to construct ICMP packet: {e}");
                    continue;
                }
                if let Err(e) = l3_packet_sink.send(icmp_packet).await {
                    error!("Failed to send ICMP packet to host: {e}");
                }
                continue;
            }
            // No need to verify destination address, if it is not part of the global subnet there
            // should not be a route for it, and therefore the route step will generate the
            // appropriate ICMP.
            let mut header = packet.header_mut();
            header[0] = USER_DATA_VERSION;
            header[1] = USER_DATA_L3_TYPE;
            // A returned packet is a locally generated ICMP error to hand back to the host.
            if let Some(icmp) = self.encrypt_and_route_packet(src_ip, dst_ip, hop_limit, packet) {
                if let Err(e) = l3_packet_sink.send(icmp).await {
                    error!("Could not forward icmp packet back to TUN interface {e}");
                }
            }
        }
        warn!("Data inject loop from host to router ended");
    }
    /// Inject a new packet where the content is a `message` fragment.
    pub fn inject_message_packet(
        &self,
        src_ip: Ipv6Addr,
        dst_ip: Ipv6Addr,
        mut packet: PacketBuffer,
    ) {
        let mut header = packet.header_mut();
        header[0] = USER_DATA_VERSION;
        header[1] = USER_DATA_MESSAGE_TYPE;
        // Any ICMP generated here is silently discarded: there is no l3 host to return it to.
        self.encrypt_and_route_packet(src_ip, dst_ip, MESSAGE_HOP_LIMIT, packet);
    }
    /// Encrypt the content of a packet based on the destination key, and then inject the packet
    /// into the [`Router`] for processing.
    ///
    /// If no key exists for the destination, the content can't be encrypted, the packet is not
    /// injected into the router, and a packet is returned containing an ICMP packet. Note that a
    /// return value of [`Option::None`] does not mean the packet was successfully forwarded;
    fn encrypt_and_route_packet(
        &self,
        src_ip: Ipv6Addr,
        dst_ip: Ipv6Addr,
        hop_limit: u8,
        packet: PacketBuffer,
    ) -> Option<PacketBuffer> {
        // If the packet only has a TTL of 1, we won't be able to route it to the destination
        // regardless, so just reply with an unencrypted TTL exceeded ICMP.
        if hop_limit < 2 {
            debug!(
                packet.ttl = hop_limit,
                packet.src = %src_ip,
                packet.dst = %dst_ip,
                "Attempting to route packet with insufficient TTL",
            );
            let mut pb = PacketBuffer::new();
            // From self to self
            let icmp = PacketBuilder::ipv6(src_ip.octets(), src_ip.octets(), hop_limit)
                .icmpv6(Icmpv6Type::TimeExceeded(TimeExceededCode::HopLimitExceeded));
            // Scale to max size if needed
            let orig_buf_end = packet
                .buffer()
                .len()
                .min(MIN_IPV6_MTU - IPV6_MIN_HEADER_SIZE - ICMP6_HEADER_SIZE);
            pb.set_size(icmp.size(orig_buf_end));
            let mut b = pb.buffer_mut();
            if let Err(e) = icmp.write(&mut b, &packet.buffer()[..orig_buf_end]) {
                error!("Failed to construct time exceeded ICMP packet {e}");
                return None;
            }
            return Some(pb);
        }
        // Get shared secret from node and dest address
        let shared_secret = match self.router.get_shared_secret_if_selected(dst_ip.into()) {
            Some(ss) => ss,
            // If we don't have a route to the destination subnet, reply with ICMP no route to
            // host. Do this here as well to avoid encrypting the ICMP to ourselves.
            None => {
                debug!(
                    packet.src = %src_ip,
                    packet.dst = %dst_ip,
                    "No entry found for destination address, dropping packet",
                );
                let mut pb = PacketBuffer::new();
                // From self to self
                let icmp = PacketBuilder::ipv6(src_ip.octets(), src_ip.octets(), hop_limit).icmpv6(
                    Icmpv6Type::DestinationUnreachable(DestUnreachableCode::NoRoute),
                );
                // Scale to max size if needed
                let orig_buf_end = packet
                    .buffer()
                    .len()
                    .min(MIN_IPV6_MTU - IPV6_MIN_HEADER_SIZE - ICMP6_HEADER_SIZE);
                pb.set_size(icmp.size(orig_buf_end));
                let mut b = pb.buffer_mut();
                if let Err(e) = icmp.write(&mut b, &packet.buffer()[..orig_buf_end]) {
                    error!("Failed to construct no route to host ICMP packet {e}");
                    return None;
                }
                return Some(pb);
            }
        };
        self.router.route_packet(DataPacket {
            dst_ip,
            src_ip,
            hop_limit,
            raw_data: shared_secret.encrypt(packet),
        });
        None
    }
    /// Receive packets routed to this node, decrypt them, and dispatch them to the host
    /// (l3 sink), the message subsystem, or reconstruct out-of-band ICMP errors.
    async fn extract_packet_loop<T, U>(
        self,
        mut l3_packet_sink: T,
        mut message_packet_sink: U,
        mut host_packet_source: UnboundedReceiver<DataPacket>,
    ) where
        T: Sink<PacketBuffer> + Send + Unpin + 'static,
        T::Error: std::fmt::Display,
        U: Sink<(PacketBuffer, IpAddr, IpAddr)> + Send + Unpin + 'static,
        U::Error: std::fmt::Display,
    {
        while let Some(data_packet) = host_packet_source.recv().await {
            // decrypt & send to TUN interface
            let shared_secret = if let Some(ss) = self
                .router
                .get_shared_secret_from_dest(data_packet.src_ip.into())
            {
                ss
            } else {
                trace!("Received packet from unknown sender");
                continue;
            };
            let mut decrypted_packet = match shared_secret.decrypt(data_packet.raw_data) {
                Ok(data) => data,
                Err(_) => {
                    debug!("Dropping data packet with invalid encrypted content");
                    continue;
                }
            };
            // Check header
            let header = decrypted_packet.header();
            if header[0] != USER_DATA_VERSION {
                trace!("Dropping decrypted packet with unknown header version");
                continue;
            }
            // Route based on packet type.
            match header[1] {
                USER_DATA_L3_TYPE => {
                    let real_packet = decrypted_packet.buffer_mut();
                    if real_packet.len() < IPV6_MIN_HEADER_SIZE {
                        debug!(
                            "Decrypted packet is too short, can't possibly be a valid IPv6 packet"
                        );
                        continue;
                    }
                    // Adjust the hop limit in the decrypted packet to the new value.
                    real_packet[7] = data_packet.hop_limit;
                    if let Err(e) = l3_packet_sink.send(decrypted_packet).await {
                        error!("Failed to send packet on local TUN interface: {e}",);
                        continue;
                    }
                }
                USER_DATA_MESSAGE_TYPE => {
                    if let Err(e) = message_packet_sink
                        .send((
                            decrypted_packet,
                            IpAddr::V6(data_packet.src_ip),
                            IpAddr::V6(data_packet.dst_ip),
                        ))
                        .await
                    {
                        error!("Failed to send packet to message handler: {e}",);
                        continue;
                    }
                }
                USER_DATA_OOB_ICMP => {
                    // Layout (established by the sender): 16 bytes original destination IP,
                    // followed by an IPv6 + ICMPv6 header, whose payload is encrypted to the
                    // key of the original destination.
                    let real_packet = &*decrypted_packet;
                    if real_packet.len() < IPV6_MIN_HEADER_SIZE + ICMP6_HEADER_SIZE + 16 {
                        debug!(
                            "Decrypted packet is too short, can't possibly be a valid IPv6 ICMP packet"
                        );
                        continue;
                    }
                    if real_packet.len() > MIN_IPV6_MTU + 16 {
                        debug!("Discarding ICMP packet which is too large");
                        continue;
                    }
                    let dec_ip = Ipv6Addr::from(
                        <&[u8] as TryInto<[u8; 16]>>::try_into(&real_packet[..16]).unwrap(),
                    );
                    trace!("ICMP for original target {dec_ip}");
                    let key =
                        if let Some(key) = self.router.get_shared_secret_from_dest(dec_ip.into()) {
                            key
                        } else {
                            debug!("Can't decrypt OOB ICMP packet from unknown host");
                            continue;
                        };
                    let (_, body) = match etherparse::IpHeaders::from_slice(&real_packet[16..]) {
                        Ok(r) => r,
                        Err(e) => {
                            // This is a node which does not adhere to the protocol of sending back
                            // ICMP like this, or it is intentionally sending malicious packets.
                            debug!(
                                "Dropping malformed OOB ICMP packet from {} for {e}",
                                data_packet.src_ip
                            );
                            continue;
                        }
                    };
                    let (header, body) = match etherparse::Icmpv6Header::from_slice(body.payload) {
                        Ok(r) => r,
                        Err(e) => {
                            // This is a node which does not adhere to the protocol of sending back
                            // ICMP like this, or it is intentionally sending malicious packets.
                            debug!(
                                "Dropping OOB ICMP packet from {} with malformed ICMP header ({e})",
                                data_packet.src_ip
                            );
                            continue;
                        }
                    };
                    // NOTE(review): `body[..body.len()]` is equivalent to the whole of `body`;
                    // the original comment wondered where leftover bytes come from — unresolved.
                    let orig_pb = match key.decrypt(body[..body.len()].to_vec()) {
                        Ok(pb) => pb,
                        Err(e) => {
                            warn!("Failed to decrypt ICMP data body {e}");
                            continue;
                        }
                    };
                    // Rebuild a regular ICMPv6 packet for the host from the decrypted payload.
                    let packet = etherparse::PacketBuilder::ipv6(
                        data_packet.src_ip.octets(),
                        data_packet.dst_ip.octets(),
                        data_packet.hop_limit,
                    )
                    .icmpv6(header.icmp_type);
                    let serialized_icmp = packet.size(orig_pb.len());
                    let mut rp = PacketBuffer::new();
                    rp.set_size(serialized_icmp);
                    if let Err(e) =
                        packet.write(&mut (&mut rp.buffer_mut()[..serialized_icmp]), &orig_pb)
                    {
                        error!("Could not reconstruct icmp packet {e}");
                        continue;
                    }
                    if let Err(e) = l3_packet_sink.send(rp).await {
                        error!("Failed to send packet on local TUN interface: {e}",);
                        continue;
                    }
                }
                _ => {
                    trace!("Dropping decrypted packet with unknown protocol type");
                    continue;
                }
            }
        }
        warn!("Extract loop from router to host ended");
    }
}
impl<M> Clone for DataPlane<M>
where
M: Clone,
{
fn clone(&self) -> Self {
Self {
router: self.router.clone(),
}
}
}

116
mycelium/src/endpoint.rs Normal file
View File

@@ -0,0 +1,116 @@
use std::{
fmt,
net::{AddrParseError, SocketAddr},
str::FromStr,
};
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, PartialEq, Eq)]
/// Error generated while processing improperly formatted endpoints.
pub enum EndpointParseError {
    /// An address was specified without leading protocol information.
    MissingProtocol,
    /// An endpoint was specified using a protocol we (currently) do not understand.
    UnknownProtocol,
    /// Error while parsing the specific address.
    Address(AddrParseError),
}
/// Protocol used by an endpoint.
// Serialized in camelCase ("tcp", "tls", "quic") for API compatibility.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum Protocol {
    /// Standard plain text Tcp.
    Tcp,
    /// Tls 1.3 with PSK over Tcp.
    Tls,
    /// Quic protocol (over UDP).
    Quic,
}
/// An endpoint defines an address and a protocol to use when communicating with it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Endpoint {
    proto: Protocol,
    socket_addr: SocketAddr,
}
impl Endpoint {
    /// Create a new `Endpoint` with given [`Protocol`] and address.
    pub fn new(proto: Protocol, socket_addr: SocketAddr) -> Self {
        Self { proto, socket_addr }
    }
    /// Get the [`Protocol`] used by this `Endpoint`.
    pub fn proto(&self) -> Protocol {
        self.proto
    }
    /// Get the [`SocketAddr`] used by this `Endpoint`.
    pub fn address(&self) -> SocketAddr {
        self.socket_addr
    }
}
impl FromStr for Endpoint {
    type Err = EndpointParseError;

    /// Parse an endpoint of the form `<protocol>://<socket address>`, where the
    /// protocol is one of `tcp`, `tls` or `quic` (case insensitive).
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let (proto_part, addr_part) = s
            .split_once("://")
            .ok_or(EndpointParseError::MissingProtocol)?;
        let proto = match proto_part.to_lowercase().as_str() {
            "tcp" => Protocol::Tcp,
            "quic" => Protocol::Quic,
            "tls" => Protocol::Tls,
            _ => return Err(EndpointParseError::UnknownProtocol),
        };
        let socket_addr = SocketAddr::from_str(addr_part)?;
        Ok(Endpoint { proto, socket_addr })
    }
}
impl fmt::Display for Endpoint {
    // NOTE(review): this renders as "<Proto> <addr>" (space separated), which is NOT
    // re-parseable by the `FromStr` impl (that expects "<proto>://<addr>"). Presumably
    // intentional human-readable output — confirm before changing the format.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!("{} {}", self.proto, self.socket_addr))
    }
}
impl fmt::Display for Protocol {
    /// Human-readable protocol name (capitalized, unlike the serde form).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::Tcp => "Tcp",
            Self::Tls => "Tls",
            Self::Quic => "Quic",
        };
        f.write_str(name)
    }
}
impl fmt::Display for EndpointParseError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::MissingProtocol => f.write_str("missing leading protocol identifier"),
Self::UnknownProtocol => f.write_str("protocol for endpoint is not supported"),
Self::Address(e) => f.write_fmt(format_args!("failed to parse address: {e}")),
}
}
}
impl std::error::Error for EndpointParseError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Self::Address(e) => Some(e),
_ => None,
}
}
}
impl From<AddrParseError> for EndpointParseError {
fn from(value: AddrParseError) -> Self {
Self::Address(value)
}
}

53
mycelium/src/filters.rs Normal file
View File

@@ -0,0 +1,53 @@
use crate::{babel, subnet::Subnet};
/// This trait is used to filter incoming updates from peers. Only updates which pass all
/// configured filters on the local [`Router`](crate::router::Router) will actually be forwarded
/// to the [`Router`](crate::router::Router) for processing.
pub trait RouteUpdateFilter {
    /// Judge an incoming update.
    ///
    /// Returns `true` if the update may be processed, `false` if it must be discarded.
    fn allow(&self, update: &babel::Update) -> bool;
}
/// Limit the subnet size of subnets announced in updates to be at most `N` bits. Note that "at
/// most" here means that the actual prefix length needs to be **AT LEAST** this value.
pub struct MaxSubnetSize<const N: u8>;

impl<const N: u8> RouteUpdateFilter for MaxSubnetSize<N> {
    fn allow(&self, update: &babel::Update) -> bool {
        // A longer prefix means a smaller subnet, hence ">=" enforces the maximum size.
        update.subnet().prefix_len() >= N
    }
}
/// Only pass announcements whose subnet is fully contained in a configured subnet.
pub struct AllowedSubnet {
    subnet: Subnet,
}
impl AllowedSubnet {
    /// Create a filter which only passes updates whose announced `Subnet` lies inside the
    /// given `Subnet`.
    pub fn new(subnet: Subnet) -> Self {
        AllowedSubnet { subnet }
    }
}
impl RouteUpdateFilter for AllowedSubnet {
    fn allow(&self, update: &babel::Update) -> bool {
        let announced = update.subnet();
        self.subnet.contains_subnet(&announced)
    }
}
/// Only pass announcements for subnets which contain the IP derived from the announcing
/// `RouterId`.
///
/// Retractions (infinite metric) can legitimately be sent by any node to indicate it has no
/// route for the subnet, so those always pass.
pub struct RouterIdOwnsSubnet;
impl RouteUpdateFilter for RouterIdOwnsSubnet {
    fn allow(&self, update: &babel::Update) -> bool {
        // A retraction carries no usable route, so ownership is irrelevant.
        if update.metric().is_infinite() {
            return true;
        }
        let derived_ip = update.router_id().to_pubkey().address().into();
        update.subnet().contains_ip(derived_ip)
    }
}

38
mycelium/src/interval.rs Normal file
View File

@@ -0,0 +1,38 @@
//! Dedicated logic for
//! [intervals](https://datatracker.ietf.org/doc/html/rfc8966#name-solving-starvation-sequenci).
use std::time::Duration;
/// An interval in the babel protocol.
///
/// Intervals represent a duration, and are expressed in centiseconds (0.01 second / 10
/// milliseconds). `Interval` implements [`From`] [`u16`] to create a new interval from a raw
/// value, and [`From`] [`Duration`] to create a new `Interval` from an existing [`Duration`].
/// There are also implementations to convert back to the aforementioned types. Note that in
/// case of duration, millisecond precision is lost, and durations longer than `u16::MAX`
/// centiseconds saturate at `u16::MAX`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Interval(u16);
impl From<Duration> for Interval {
    fn from(value: Duration) -> Self {
        // Saturate instead of casting with `as`: a `Duration` longer than u16::MAX
        // centiseconds (just under 11 minutes) previously wrapped around, silently
        // producing a much smaller interval.
        Interval(u16::try_from(value.as_millis() / 10).unwrap_or(u16::MAX))
    }
}
impl From<Interval> for Duration {
    fn from(value: Interval) -> Self {
        // 1 centisecond == 10 milliseconds.
        Duration::from_millis(value.0 as u64 * 10)
    }
}
impl From<u16> for Interval {
    fn from(value: u16) -> Self {
        Interval(value)
    }
}
impl From<Interval> for u16 {
    fn from(value: Interval) -> Self {
        value.0
    }
}

462
mycelium/src/lib.rs Normal file
View File

@@ -0,0 +1,462 @@
use std::net::{IpAddr, Ipv6Addr};
use std::path::PathBuf;
#[cfg(feature = "message")]
use std::{future::Future, time::Duration};
use crate::cdn::Cdn;
use crate::tun::TunConfig;
use bytes::BytesMut;
use data::DataPlane;
use endpoint::Endpoint;
#[cfg(feature = "message")]
use message::TopicConfig;
#[cfg(feature = "message")]
use message::{
MessageId, MessageInfo, MessagePushResponse, MessageStack, PushMessageError, ReceivedMessage,
};
use metrics::Metrics;
use peer_manager::{PeerExists, PeerNotFound, PeerStats, PrivateNetworkKey};
use routing_table::{NoRouteSubnet, QueriedSubnet, RouteEntry};
use subnet::Subnet;
use tokio::net::TcpListener;
use tracing::{error, info, warn};
mod babel;
pub mod cdn;
mod connection;
pub mod crypto;
pub mod data;
pub mod endpoint;
pub mod filters;
mod interval;
#[cfg(feature = "message")]
pub mod message;
mod metric;
pub mod metrics;
pub mod packet;
mod peer;
pub mod peer_manager;
pub mod router;
mod router_id;
mod routing_table;
mod rr_cache;
mod seqno_cache;
mod sequence_number;
mod source_table;
pub mod subnet;
pub mod task;
mod tun;
/// The prefix of the global subnet used (`400::`, forming `400::/7` together with
/// [`GLOBAL_SUBNET_PREFIX_LEN`]).
pub const GLOBAL_SUBNET_ADDRESS: IpAddr = IpAddr::V6(Ipv6Addr::new(0x400, 0, 0, 0, 0, 0, 0, 0));
/// The prefix length of the global subnet used.
pub const GLOBAL_SUBNET_PREFIX_LEN: u8 = 7;
/// Config for a mycelium [`Node`].
pub struct Config<M> {
    /// The secret key of the node.
    pub node_key: crypto::SecretKey,
    /// Statically configured peers.
    pub peers: Vec<Endpoint>,
    /// Tun interface should be disabled.
    pub no_tun: bool,
    /// Listen port for TCP connections.
    pub tcp_listen_port: u16,
    /// Listen port for Quic connections.
    pub quic_listen_port: Option<u16>,
    /// Udp port for peer discovery.
    pub peer_discovery_port: Option<u16>,
    /// Name for the TUN device.
    #[cfg(any(
        target_os = "linux",
        all(target_os = "macos", not(feature = "mactunfd")),
        target_os = "windows"
    ))]
    pub tun_name: String,
    /// Configuration for a private network, if run in that mode. To enable private networking,
    /// this must be a name + a PSK.
    pub private_network_config: Option<(String, PrivateNetworkKey)>,
    /// Implementation of the `Metrics` trait, used to expose information about the system
    /// internals.
    pub metrics: M,
    /// Mark that's set on all packets that we send on the underlying network
    pub firewall_mark: Option<u32>,
    // tun_fd is android, iOS, macos on appstore specific option
    // We can't create TUN device from the Rust code in android, iOS, and macos on appstore.
    // So, we create the TUN device on Kotlin(android) or Swift(iOS, macos) then pass
    // the TUN's file descriptor to mycelium.
    #[cfg(any(
        target_os = "android",
        target_os = "ios",
        all(target_os = "macos", feature = "mactunfd"),
    ))]
    pub tun_fd: Option<i32>,
    /// The amount of worker tasks spawned to process updates. Up to this amount of updates can be
    /// processed in parallel. Because processing an update is a CPU bound task, it is pointless to
    /// set this to a value which is higher than the amount of logical CPU cores available to the
    /// system.
    pub update_workers: usize,
    /// Directory used as on-disk cache by the CDN subsystem; `None` disables the CDN.
    pub cdn_cache: Option<PathBuf>,
    /// Configuration for message topics, if this is not set the default config will be used.
    #[cfg(feature = "message")]
    pub topic_config: Option<TopicConfig>,
}
/// The Node is the main structure in mycelium. It governs the entire data flow.
pub struct Node<M> {
    // Routing control plane.
    router: router::Router<M>,
    // Manages connections to remote peers.
    peer_manager: peer_manager::PeerManager<M>,
    // CDN handle; held only so it stays alive for the lifetime of the node.
    _cdn: Option<Cdn>,
    // Message send/receive stack, only present with the `message` feature.
    #[cfg(feature = "message")]
    message_stack: message::MessageStack<M>,
}
/// General info about a node.
pub struct NodeInfo {
    /// The overlay subnet in use by the node.
    pub node_subnet: Subnet,
    /// The public key of the node.
    pub node_pubkey: crypto::PublicKey,
}
impl<M> Node<M>
where
    M: Metrics + Clone + Send + Sync + 'static,
{
    /// Setup a new `Node` with the provided [`Config`].
    ///
    /// This validates the (optional) private network name, then wires up the router, the peer
    /// manager, the data plane (with or without a TUN device) and, if configured, the CDN and
    /// message stack.
    pub async fn new(config: Config<M>) -> Result<Self, Box<dyn std::error::Error>> {
        // If a private network is configured, validate network name
        if let Some((net_name, _)) = &config.private_network_config {
            if net_name.len() < 2 || net_name.len() > 64 {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    "network name must be between 2 and 64 characters",
                )
                .into());
            }
        }
        let node_pub_key = crypto::PublicKey::from(&config.node_key);
        let node_addr = node_pub_key.address();
        // Channel over which the router (tun_tx) hands packets to the data plane (tun_rx).
        let (tun_tx, tun_rx) = tokio::sync::mpsc::unbounded_channel();
        let node_subnet = Subnet::new(
            // Truncate last 64 bits of address.
            // TODO: find a better way to do this.
            Subnet::new(node_addr.into(), 64)
                .expect("64 is a valid IPv6 prefix size; qed")
                .network(),
            64,
        )
        .expect("64 is a valid IPv6 prefix size; qed");
        // Creating a new Router instance
        let router = match router::Router::new(
            config.update_workers,
            tun_tx,
            node_subnet,
            vec![node_subnet],
            (config.node_key, node_pub_key),
            // Only accept updates for subnets in the global range, with a prefix length of at
            // least 64, announced by a router id which owns an IP in the subnet.
            vec![
                Box::new(filters::AllowedSubnet::new(
                    Subnet::new(GLOBAL_SUBNET_ADDRESS, GLOBAL_SUBNET_PREFIX_LEN)
                        .expect("Global subnet is properly defined; qed"),
                )),
                Box::new(filters::MaxSubnetSize::<64>),
                Box::new(filters::RouterIdOwnsSubnet),
            ],
            config.metrics.clone(),
        ) {
            Ok(router) => {
                info!(
                    "Router created. Pubkey: {:x}",
                    BytesMut::from(&router.node_public_key().as_bytes()[..])
                );
                router
            }
            Err(e) => {
                // NOTE(review): this panics even though `new` returns a `Result`; consider
                // propagating the error to the caller instead.
                error!("Error creating router: {e}");
                panic!("Error creating router: {e}");
            }
        };
        // Creating a new PeerManager instance
        let pm = peer_manager::PeerManager::new(
            router.clone(),
            config.peers,
            config.tcp_listen_port,
            config.quic_listen_port,
            // Port 0 when no discovery port is configured; the `is_none()` flag on the next
            // line presumably disables discovery in that case — confirm against
            // `PeerManager::new`.
            config.peer_discovery_port.unwrap_or_default(),
            config.peer_discovery_port.is_none(),
            config.private_network_config,
            config.metrics,
            config.firewall_mark,
        )?;
        info!("Started peer manager");
        #[cfg(feature = "message")]
        let (tx, rx) = tokio::sync::mpsc::channel(100);
        #[cfg(feature = "message")]
        let msg_receiver = tokio_stream::wrappers::ReceiverStream::new(rx);
        #[cfg(feature = "message")]
        let msg_sender = tokio_util::sync::PollSender::new(tx);
        // Without the message feature, inbound messages are simply discarded.
        #[cfg(not(feature = "message"))]
        let msg_sender = futures::sink::drain();
        let _data_plane = if config.no_tun {
            warn!("Starting data plane without TUN interface, L3 functionality disabled");
            DataPlane::new(
                router.clone(),
                // No tun so create a dummy stream for L3 packets which never yields
                tokio_stream::pending(),
                // Similarly, create a sink which just discards every packet we would receive
                futures::sink::drain(),
                msg_sender,
                tun_rx,
            )
        } else {
            #[cfg(not(any(
                target_os = "linux",
                target_os = "macos",
                target_os = "windows",
                target_os = "android",
                target_os = "ios"
            )))]
            {
                panic!("On this platform, you can only run with --no-tun");
            }
            #[cfg(any(
                target_os = "linux",
                target_os = "macos",
                target_os = "windows",
                target_os = "android",
                target_os = "ios"
            ))]
            {
                // Platforms where mycelium itself creates the TUN device, identified by name.
                #[cfg(any(
                    target_os = "linux",
                    all(target_os = "macos", not(feature = "mactunfd")),
                    target_os = "windows"
                ))]
                let tun_config = TunConfig {
                    name: config.tun_name.clone(),
                    node_subnet: Subnet::new(node_addr.into(), 64)
                        .expect("64 is a valid subnet size for IPv6; qed"),
                    route_subnet: Subnet::new(GLOBAL_SUBNET_ADDRESS, GLOBAL_SUBNET_PREFIX_LEN)
                        .expect("Static configured TUN route is valid; qed"),
                };
                // Platforms where the host app creates the TUN device and passes in its fd.
                #[cfg(any(
                    target_os = "android",
                    target_os = "ios",
                    all(target_os = "macos", feature = "mactunfd"),
                ))]
                let tun_config = TunConfig {
                    tun_fd: config.tun_fd.unwrap(),
                };
                let (rxhalf, txhalf) = tun::new(tun_config).await?;
                info!("Node overlay IP: {node_addr}");
                DataPlane::new(router.clone(), rxhalf, txhalf, msg_sender, tun_rx)
            }
        };
        let cdn = config.cdn_cache.map(Cdn::new);
        if let Some(ref cdn) = cdn {
            // NOTE(review): this binds the CDN listener on privileged port 80, which fails
            // without elevated privileges — confirm this is intended.
            let listener = TcpListener::bind("localhost:80").await?;
            cdn.start(listener)?;
        }
        #[cfg(feature = "message")]
        let ms = MessageStack::new(_data_plane, msg_receiver, config.topic_config);
        Ok(Node {
            router,
            peer_manager: pm,
            _cdn: cdn,
            #[cfg(feature = "message")]
            message_stack: ms,
        })
    }
    /// Get information about the running `Node`
    pub fn info(&self) -> NodeInfo {
        NodeInfo {
            node_subnet: self.router.node_tun_subnet(),
            node_pubkey: self.router.node_public_key(),
        }
    }
    /// Get information about the current peers in the `Node`
    pub fn peer_info(&self) -> Vec<PeerStats> {
        self.peer_manager.peers()
    }
    /// Add a new peer to the system identified by an [`Endpoint`].
    pub fn add_peer(&self, endpoint: Endpoint) -> Result<(), PeerExists> {
        self.peer_manager.add_peer(endpoint)
    }
    /// Remove an existing peer identified by an [`Endpoint`] from the system.
    pub fn remove_peer(&self, endpoint: Endpoint) -> Result<(), PeerNotFound> {
        self.peer_manager.delete_peer(&endpoint)
    }
    /// List all selected [`routes`](RouteEntry) in the system.
    pub fn selected_routes(&self) -> Vec<RouteEntry> {
        self.router.load_selected_routes()
    }
    /// List all fallback [`routes`](RouteEntry) in the system.
    pub fn fallback_routes(&self) -> Vec<RouteEntry> {
        self.router.load_fallback_routes()
    }
    /// List all [`queried subnets`](QueriedSubnet) in the system.
    pub fn queried_subnets(&self) -> Vec<QueriedSubnet> {
        self.router.load_queried_subnets()
    }
    /// List all [`subnets with no route`](NoRouteSubnet) in the system.
    pub fn no_route_entries(&self) -> Vec<NoRouteSubnet> {
        self.router.load_no_route_entries()
    }
    /// Get public key from the IP of `Node`
    pub fn get_pubkey_from_ip(&self, ip: IpAddr) -> Option<crypto::PublicKey> {
        self.router.get_pubkey(ip)
    }
}
#[cfg(feature = "message")]
impl<M> Node<M>
where
    M: Metrics + Clone + Send + 'static,
{
    /// Wait for a message to arrive in the message stack.
    ///
    /// If the optional `topic` is provided, only messages which have exactly the same value in
    /// `topic` will be returned. The `pop` argument decides if the message is removed from the
    /// internal queue or not. If `pop` is `false`, the same message will be returned on the next
    /// call (with the same topic).
    ///
    /// This method returns a future which will wait indefinitely until a message is received. It
    /// is generally a good idea to put a limit on how long to wait by wrapping this in a [`tokio::time::timeout`].
    pub fn get_message(
        &self,
        pop: bool,
        topic: Option<Vec<u8>>,
    ) -> impl Future<Output = ReceivedMessage> + '_ {
        // First reborrow only the message stack from self, then manually construct a future. This
        // avoids a lifetime issue on the router, which is not sync. If a regular 'async' fn would
        // be used here, we can't specify that at this point sadly.
        let ms = &self.message_stack;
        async move { ms.message(pop, topic).await }
    }
    /// Push a new message to the message stack.
    ///
    /// The system will attempt to transmit the message for `try_duration`. A message is considered
    /// transmitted when the receiver has indicated it completely received the message. If
    /// `subscribe_reply` is `true`, the second return value will be [`Option::Some`], with a
    /// watcher which will resolve if a reply for this exact message comes in. Since this relies on
    /// the receiver actually sending a reply, there is no guarantee that this will eventually
    /// resolve.
    pub fn push_message(
        &self,
        dst: IpAddr,
        data: Vec<u8>,
        topic: Option<Vec<u8>>,
        try_duration: Duration,
        subscribe_reply: bool,
    ) -> Result<MessagePushResponse, PushMessageError> {
        self.message_stack.new_message(
            dst,
            data,
            topic.unwrap_or_default(),
            try_duration,
            subscribe_reply,
        )
    }
    /// Get the status of a message sent previously.
    ///
    /// Returns [`Option::None`] if no message is found with the given id. Message info is only
    /// retained for a limited time after a message has been received, or after the message has
    /// been aborted due to a timeout.
    pub fn message_status(&self, id: MessageId) -> Option<MessageInfo> {
        self.message_stack.message_info(id)
    }
    /// Send a reply to a previously received message.
    pub fn reply_message(
        &self,
        id: MessageId,
        dst: IpAddr,
        data: Vec<u8>,
        try_duration: Duration,
    ) -> MessageId {
        self.message_stack
            .reply_message(id, dst, data, try_duration)
    }
    /// Get a list of all configured topics
    pub fn topics(&self) -> Vec<Vec<u8>> {
        self.message_stack.topics()
    }
    /// Get the whitelisted source subnets for the given topic, if the topic is configured.
    pub fn topic_allowed_sources(&self, topic: &Vec<u8>) -> Option<Vec<Subnet>> {
        self.message_stack.topic_allowed_sources(topic)
    }
    /// Sets the default topic action to accept or reject. This decides how topics which don't have
    /// an explicit whitelist get handled.
    pub fn accept_unconfigured_topic(&self, accept: bool) {
        self.message_stack.set_default_topic_action(accept)
    }
    /// Whether a topic without default configuration is accepted or not.
    // NOTE(review): the name looks like a typo for `unconfigured_topic_action`, but renaming
    // would break callers.
    pub fn unconfigure_topic_action(&self) -> bool {
        self.message_stack.get_default_topic_action()
    }
    /// Add a topic to the whitelist without any configured allowed sources.
    pub fn add_topic_whitelist(&self, topic: Vec<u8>) {
        self.message_stack.add_topic_whitelist(topic)
    }
    /// Remove a topic from the whitelist. Future messages will follow the default action.
    pub fn remove_topic_whitelist(&self, topic: Vec<u8>) {
        self.message_stack.remove_topic_whitelist(topic)
    }
    /// Add a new whitelisted source for a topic. This creates the topic if it does not exist yet.
    pub fn add_topic_whitelist_src(&self, topic: Vec<u8>, src: Subnet) {
        self.message_stack.add_topic_whitelist_src(topic, src)
    }
    /// Remove a whitelisted source for a topic.
    pub fn remove_topic_whitelist_src(&self, topic: Vec<u8>, src: Subnet) {
        self.message_stack.remove_topic_whitelist_src(topic, src)
    }
    /// Set the forward socket for a topic. Creates the topic if it doesn't exist.
    pub fn set_topic_forward_socket(&self, topic: Vec<u8>, socket_path: std::path::PathBuf) {
        self.message_stack
            .set_topic_forward_socket(topic, Some(socket_path))
    }
    /// Get the forward socket for a topic, if any.
    pub fn get_topic_forward_socket(&self, topic: &Vec<u8>) -> Option<std::path::PathBuf> {
        self.message_stack.get_topic_forward_socket(topic)
    }
    /// Removes the forward socket for the topic, if one exists
    pub fn delete_topic_forward_socket(&self, topic: Vec<u8>) {
        self.message_stack.set_topic_forward_socket(topic, None)
    }
}

1867
mycelium/src/message.rs Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,254 @@
use std::fmt;
use super::MessagePacket;
/// A message representing a "chunk" message.
///
/// The body of a chunk message has the following structure:
/// - 8 bytes: chunk index
/// - 8 bytes: chunk offset
/// - 8 bytes: chunk size
/// - remainder: chunk data of length based on field 3
pub struct MessageChunk {
    // Packet backing this view; the body layout described above lives in its buffer.
    buffer: MessagePacket,
}
impl MessageChunk {
    /// Byte offset of the chunk index field in the body.
    const IDX_FIELD: usize = 0;
    /// Byte offset of the chunk offset field in the body.
    const OFFSET_FIELD: usize = 8;
    /// Byte offset of the chunk size field in the body.
    const SIZE_FIELD: usize = 16;
    /// Size of the fixed chunk body header (index + offset + size); chunk data starts here.
    const HEADER_SIZE: usize = 24;
    /// Create a new `MessageChunk` in the provided [`MessagePacket`].
    pub fn new(mut buffer: MessagePacket) -> Self {
        buffer.set_used_buffer_size(Self::HEADER_SIZE);
        buffer.header_mut().flags_mut().set_chunk();
        Self { buffer }
    }
    /// Return the index of the chunk in the message, as written in the body.
    pub fn chunk_idx(&self) -> u64 {
        u64::from_be_bytes(
            self.buffer.buffer()[Self::IDX_FIELD..Self::OFFSET_FIELD]
                .try_into()
                .expect("Buffer contains a size field of valid length; qed"),
        )
    }
    /// Set the index of the chunk in the message body.
    pub fn set_chunk_idx(&mut self, chunk_idx: u64) {
        self.buffer.buffer_mut()[Self::IDX_FIELD..Self::OFFSET_FIELD]
            .copy_from_slice(&chunk_idx.to_be_bytes())
    }
    /// Return the chunk offset in the message, as written in the body.
    pub fn chunk_offset(&self) -> u64 {
        u64::from_be_bytes(
            self.buffer.buffer()[Self::OFFSET_FIELD..Self::SIZE_FIELD]
                .try_into()
                .expect("Buffer contains a size field of valid length; qed"),
        )
    }
    /// Set the offset of the chunk in the message body.
    pub fn set_chunk_offset(&mut self, chunk_offset: u64) {
        self.buffer.buffer_mut()[Self::OFFSET_FIELD..Self::SIZE_FIELD]
            .copy_from_slice(&chunk_offset.to_be_bytes())
    }
    /// Return the size of the chunk in the message, as written in the body.
    pub fn chunk_size(&self) -> u64 {
        // Shield against a corrupt value.
        u64::min(
            u64::from_be_bytes(
                self.buffer.buffer()[Self::SIZE_FIELD..Self::HEADER_SIZE]
                    .try_into()
                    .expect("Buffer contains a size field of valid length; qed"),
            ),
            self.buffer.buffer().len() as u64 - Self::HEADER_SIZE as u64,
        )
    }
    /// Set the size of the chunk in the message body.
    pub fn set_chunk_size(&mut self, chunk_size: u64) {
        self.buffer.buffer_mut()[Self::SIZE_FIELD..Self::HEADER_SIZE]
            .copy_from_slice(&chunk_size.to_be_bytes())
    }
    /// Return a reference to the chunk data in the message.
    pub fn data(&self) -> &[u8] {
        &self.buffer.buffer()[Self::HEADER_SIZE..Self::HEADER_SIZE + self.chunk_size() as usize]
    }
    /// Set the chunk data in this message. This will also set the size field to the proper value.
    pub fn set_chunk_data(&mut self, data: &[u8]) -> Result<(), InsufficientChunkSpace> {
        let buf = self.buffer.buffer_mut();
        let available_space = buf.len() - Self::HEADER_SIZE;
        if data.len() > available_space {
            return Err(InsufficientChunkSpace {
                available: available_space,
                needed: data.len(),
            });
        }
        // Slicing based on data.len() is fine here as we just checked to make sure we can handle
        // this capacity.
        buf[Self::HEADER_SIZE..Self::HEADER_SIZE + data.len()].copy_from_slice(data);
        self.set_chunk_size(data.len() as u64);
        // Also set the extra space used by the buffer on the underlying packet.
        self.buffer.set_used_buffer_size(Self::HEADER_SIZE + data.len());
        Ok(())
    }
    /// Convert the `MessageChunk` into a reply. This does nothing if it is already a reply.
    pub fn into_reply(mut self) -> Self {
        self.buffer.header_mut().flags_mut().set_ack();
        // We want to leave the length field in tact but don't want to copy the data in the reply.
        // This needs additional work on the underlying buffer.
        // TODO
        self
    }
    /// Consumes this `MessageChunk`, returning the underlying [`MessagePacket`].
    pub fn into_inner(self) -> MessagePacket {
        self.buffer
    }
}
/// An error indicating not enough space is available in a message to set the chunk data.
#[derive(Debug)]
pub struct InsufficientChunkSpace {
    /// Amount of space available in the chunk.
    pub available: usize,
    /// Amount of space needed to set the chunk data.
    pub needed: usize,
}
impl fmt::Display for InsufficientChunkSpace {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "Insufficient capacity available, needed {needed} bytes, have {available} bytes",
            needed = self.needed,
            available = self.available
        )
    }
}
impl std::error::Error for InsufficientChunkSpace {}
#[cfg(test)]
mod tests {
    use std::array;
    use crate::{crypto::PacketBuffer, message::MessagePacket};
    use super::MessageChunk;
    // NOTE: tests which write directly into the raw `PacketBuffer` offset their indices by 12
    // bytes, skipping the message packet header which precedes the body.
    #[test]
    fn chunk_flag_set() {
        let mc = MessageChunk::new(MessagePacket::new(PacketBuffer::new()));
        let mp = mc.into_inner();
        assert!(mp.header().flags().chunk());
    }
    #[test]
    fn read_chunk_idx() {
        let mut pb = PacketBuffer::new();
        pb.buffer_mut()[12..20].copy_from_slice(&[0, 0, 0, 0, 0, 0, 100, 73]);
        let ms = MessageChunk::new(MessagePacket::new(pb));
        assert_eq!(ms.chunk_idx(), 25_673);
    }
    #[test]
    fn write_chunk_idx() {
        let mut ms = MessageChunk::new(MessagePacket::new(PacketBuffer::new()));
        ms.set_chunk_idx(723);
        // Since we don't work with packet buffer we don't have to account for the message packet
        // header.
        assert_eq!(&ms.buffer.buffer()[..8], &[0, 0, 0, 0, 0, 0, 2, 211]);
        assert_eq!(ms.chunk_idx(), 723);
    }
    #[test]
    fn read_chunk_offset() {
        let mut pb = PacketBuffer::new();
        pb.buffer_mut()[20..28].copy_from_slice(&[0, 0, 0, 0, 0, 20, 40, 60]);
        let ms = MessageChunk::new(MessagePacket::new(pb));
        assert_eq!(ms.chunk_offset(), 1_321_020);
    }
    #[test]
    fn write_chunk_offset() {
        let mut ms = MessageChunk::new(MessagePacket::new(PacketBuffer::new()));
        ms.set_chunk_offset(1_000_000);
        // Since we don't work with packet buffer we don't have to account for the message packet
        // header.
        assert_eq!(&ms.buffer.buffer()[8..16], &[0, 0, 0, 0, 0, 15, 66, 64]);
        assert_eq!(ms.chunk_offset(), 1_000_000);
    }
    #[test]
    fn read_chunk_size() {
        let mut pb = PacketBuffer::new();
        pb.buffer_mut()[28..36].copy_from_slice(&[0, 0, 0, 0, 0, 0, 3, 232]);
        let ms = MessageChunk::new(MessagePacket::new(pb));
        assert_eq!(ms.chunk_size(), 1_000);
    }
    #[test]
    fn write_chunk_size() {
        let mut ms = MessageChunk::new(MessagePacket::new(PacketBuffer::new()));
        ms.set_chunk_size(1_300);
        // Since we don't work with packet buffer we don't have to account for the message packet
        // header.
        assert_eq!(&ms.buffer.buffer()[16..24], &[0, 0, 0, 0, 0, 0, 5, 20]);
        assert_eq!(ms.chunk_size(), 1_300);
    }
    #[test]
    fn read_chunk_data() {
        const CHUNK_DATA: &[u8] = &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
        let mut pb = PacketBuffer::new();
        // Set data len
        pb.buffer_mut()[28..36].copy_from_slice(&CHUNK_DATA.len().to_be_bytes());
        pb.buffer_mut()[36..36 + CHUNK_DATA.len()].copy_from_slice(CHUNK_DATA);
        let ms = MessageChunk::new(MessagePacket::new(pb));
        assert_eq!(ms.chunk_size(), 16);
        assert_eq!(ms.data(), CHUNK_DATA);
    }
    #[test]
    fn write_chunk_data() {
        const CHUNK_DATA: &[u8] = &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
        let mut ms = MessageChunk::new(MessagePacket::new(PacketBuffer::new()));
        let res = ms.set_chunk_data(CHUNK_DATA);
        assert!(res.is_ok());
        // Since we don't work with packet buffer we don't have to account for the message packet
        // header.
        // Check and make sure size is properly set.
        assert_eq!(&ms.buffer.buffer()[16..24], &[0, 0, 0, 0, 0, 0, 0, 16]);
        assert_eq!(ms.chunk_size(), 16);
        assert_eq!(ms.data(), CHUNK_DATA);
    }
    #[test]
    fn write_chunk_data_oversized() {
        // 1500 bytes exceeds the space left in a packet buffer after the headers.
        let data: [u8; 1500] = array::from_fn(|_| 0xFF);
        let mut ms = MessageChunk::new(MessagePacket::new(PacketBuffer::new()));
        let res = ms.set_chunk_data(&data);
        assert!(res.is_err());
    }
}

View File

@@ -0,0 +1,131 @@
use super::{MessageChecksum, MessagePacket, MESSAGE_CHECKSUM_LENGTH};
/// A message representing a "done" message.
///
/// The body of a done message has the following structure:
/// - 8 bytes: chunks transmitted
/// - 32 bytes: checksum of the transmitted data
pub struct MessageDone {
    // Packet backing this view; the body layout described above lives in its buffer.
    buffer: MessagePacket,
}
impl MessageDone {
/// Create a new `MessageDone` in the provided [`MessagePacket`].
pub fn new(mut buffer: MessagePacket) -> Self {
buffer.set_used_buffer_size(40);
buffer.header_mut().flags_mut().set_done();
Self { buffer }
}
/// Return the amount of chunks in the message, as written in the body.
pub fn chunk_count(&self) -> u64 {
u64::from_be_bytes(
self.buffer.buffer()[..8]
.try_into()
.expect("Buffer contains a size field of valid length; qed"),
)
}
/// Set the amount of chunks field of the message body.
pub fn set_chunk_count(&mut self, chunk_count: u64) {
self.buffer.buffer_mut()[..8].copy_from_slice(&chunk_count.to_be_bytes())
}
/// Get the checksum of the message from the body.
pub fn checksum(&self) -> MessageChecksum {
MessageChecksum::from_bytes(
self.buffer.buffer()[8..8 + MESSAGE_CHECKSUM_LENGTH]
.try_into()
.expect("Buffer contains enough data for a checksum; qed"),
)
}
/// Set the checksum of the message in the body.
pub fn set_checksum(&mut self, checksum: MessageChecksum) {
self.buffer.buffer_mut()[8..8 + MESSAGE_CHECKSUM_LENGTH]
.copy_from_slice(checksum.as_bytes())
}
/// Convert the `MessageDone` into a reply. This does nothing if it is already a reply.
pub fn into_reply(mut self) -> Self {
self.buffer.header_mut().flags_mut().set_ack();
self
}
/// Consumes this `MessageDone`, returning the underlying [`MessagePacket`].
pub fn into_inner(self) -> MessagePacket {
self.buffer
}
}
#[cfg(test)]
mod tests {
    use crate::{
        crypto::PacketBuffer,
        message::{MessageChecksum, MessagePacket},
    };
    use super::MessageDone;
    // NOTE: tests which write directly into the raw `PacketBuffer` offset their indices by 12
    // bytes, skipping the message packet header which precedes the body.
    #[test]
    fn done_flag_set() {
        let md = MessageDone::new(MessagePacket::new(PacketBuffer::new()));
        let mp = md.into_inner();
        assert!(mp.header().flags().done());
    }
    #[test]
    fn read_chunk_count() {
        let mut pb = PacketBuffer::new();
        pb.buffer_mut()[12..20].copy_from_slice(&[0, 0, 0, 0, 0, 0, 73, 55]);
        let ms = MessageDone::new(MessagePacket::new(pb));
        assert_eq!(ms.chunk_count(), 18_743);
    }
    #[test]
    fn write_chunk_count() {
        let mut ms = MessageDone::new(MessagePacket::new(PacketBuffer::new()));
        ms.set_chunk_count(10_000);
        // Since we don't work with packet buffer we don't have to account for the message packet
        // header.
        assert_eq!(&ms.buffer.buffer()[..8], &[0, 0, 0, 0, 0, 0, 39, 16]);
        assert_eq!(ms.chunk_count(), 10_000);
    }
    #[test]
    fn read_checksum() {
        const CHECKSUM: MessageChecksum = MessageChecksum::from_bytes([
            0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D,
            0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1A, 0x1B,
            0x1C, 0x1D, 0x1E, 0x1F,
        ]);
        let mut pb = PacketBuffer::new();
        pb.buffer_mut()[20..52].copy_from_slice(CHECKSUM.as_bytes());
        let ms = MessageDone::new(MessagePacket::new(pb));
        assert_eq!(ms.checksum(), CHECKSUM);
    }
    #[test]
    fn write_checksum() {
        const CHECKSUM: MessageChecksum = MessageChecksum::from_bytes([
            0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D,
            0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1A, 0x1B,
            0x1C, 0x1D, 0x1E, 0x1F,
        ]);
        let mut ms = MessageDone::new(MessagePacket::new(PacketBuffer::new()));
        ms.set_checksum(CHECKSUM);
        // Since we don't work with packet buffer we don't have to account for the message packet
        // header.
        assert_eq!(&ms.buffer.buffer()[8..40], CHECKSUM.as_bytes());
        assert_eq!(ms.checksum(), CHECKSUM);
    }
}

View File

@@ -0,0 +1,101 @@
use super::MessagePacket;
/// A message representing an init message.
///
/// The body of an init message has the following structure:
/// - 8 bytes: message length
/// - 1 byte: topic length
/// - remainder: topic of length based on the previous field
pub struct MessageInit {
    // Packet backing this view; the body layout described above lives in its buffer.
    buffer: MessagePacket,
}
impl MessageInit {
    /// Create a new `MessageInit` in the provided [`MessagePacket`].
    pub fn new(mut buffer: MessagePacket) -> Self {
        // 8 bytes length field + 1 byte topic length; the (empty) topic follows.
        buffer.set_used_buffer_size(9);
        buffer.header_mut().flags_mut().set_init();
        Self { buffer }
    }
    /// Return the length of the message, as written in the body.
    pub fn length(&self) -> u64 {
        u64::from_be_bytes(
            self.buffer.buffer()[..8]
                .try_into()
                .expect("Buffer contains a size field of valid length; qed"),
        )
    }
    /// Return the topic of the message, as written in the body.
    pub fn topic(&self) -> &[u8] {
        let buf = self.buffer.buffer();
        // Shield against a corrupt length byte so a malformed packet can't trigger an
        // out-of-bounds slice panic (mirrors the shielding in `MessageChunk::chunk_size`).
        let topic_len = usize::min(buf[8] as usize, buf.len().saturating_sub(9));
        &buf[9..9 + topic_len]
    }
    /// Set the length field of the message body.
    pub fn set_length(&mut self, length: u64) {
        self.buffer.buffer_mut()[..8].copy_from_slice(&length.to_be_bytes())
    }
    /// Set the topic in the message body.
    ///
    /// # Panics
    ///
    /// This function panics if the topic is longer than 255 bytes.
    pub fn set_topic(&mut self, topic: &[u8]) {
        assert!(
            topic.len() <= u8::MAX as usize,
            "Topic can be 255 bytes long at most"
        );
        self.buffer.set_used_buffer_size(9 + topic.len());
        self.buffer.buffer_mut()[8] = topic.len() as u8;
        self.buffer.buffer_mut()[9..9 + topic.len()].copy_from_slice(topic);
    }
    /// Convert the `MessageInit` into a reply. This does nothing if it is already a reply.
    pub fn into_reply(mut self) -> Self {
        self.buffer.header_mut().flags_mut().set_ack();
        self
    }
    /// Consumes this `MessageInit`, returning the underlying [`MessagePacket`].
    pub fn into_inner(self) -> MessagePacket {
        self.buffer
    }
}
#[cfg(test)]
mod tests {
    use crate::{crypto::PacketBuffer, message::MessagePacket};
    use super::MessageInit;
    // NOTE: tests which write directly into the raw `PacketBuffer` offset their indices by 12
    // bytes, skipping the message packet header which precedes the body.
    #[test]
    fn init_flag_set() {
        let mi = MessageInit::new(MessagePacket::new(PacketBuffer::new()));
        let mp = mi.into_inner();
        assert!(mp.header().flags().init());
    }
    #[test]
    fn read_length() {
        let mut pb = PacketBuffer::new();
        pb.buffer_mut()[12..20].copy_from_slice(&[0, 0, 0, 0, 2, 3, 4, 5]);
        let ms = MessageInit::new(MessagePacket::new(pb));
        assert_eq!(ms.length(), 33_752_069);
    }
    #[test]
    fn write_length() {
        let mut ms = MessageInit::new(MessagePacket::new(PacketBuffer::new()));
        ms.set_length(3_432_634_632);
        // Since we don't work with packet buffer we don't have to account for the message packet
        // header.
        assert_eq!(&ms.buffer.buffer()[..8], &[0, 0, 0, 0, 204, 153, 217, 8]);
        assert_eq!(ms.length(), 3_432_634_632);
    }
}

View File

@@ -0,0 +1,230 @@
use crate::subnet::Subnet;
use core::fmt;
use serde::{
de::{Deserialize, Deserializer, MapAccess, Visitor},
Deserialize as DeserializeMacro,
};
use std::collections::HashMap;
use std::path::PathBuf;
/// Whitelist settings for a single topic: the subnets allowed to send to it, plus an optional
/// Unix domain socket messages are forwarded to.
#[derive(Debug, Default, Clone)]
pub struct TopicWhitelistConfig {
    /// Subnets that are allowed to send messages to this topic
    subnets: Vec<Subnet>,
    /// Optional Unix domain socket path to forward messages to
    forward_socket: Option<PathBuf>,
}
impl TopicWhitelistConfig {
    /// Construct an empty whitelist config (no subnets, no forward socket).
    pub fn new() -> Self {
        Default::default()
    }
    /// The subnets currently on the whitelist.
    pub fn subnets(&self) -> &Vec<Subnet> {
        &self.subnets
    }
    /// The configured forward socket path, if one is set.
    pub fn forward_socket(&self) -> Option<&PathBuf> {
        self.forward_socket.as_ref()
    }
    /// Replace the forward socket path.
    pub fn set_forward_socket(&mut self, path: Option<PathBuf>) {
        self.forward_socket = path;
    }
    /// Append a subnet to the whitelist.
    pub fn add_subnet(&mut self, subnet: Subnet) {
        self.subnets.push(subnet);
    }
    /// Drop every occurrence of the given subnet from the whitelist.
    pub fn remove_subnet(&mut self, subnet: &Subnet) {
        self.subnets.retain(|candidate| candidate != subnet);
    }
}
#[derive(Debug, Default, Clone)]
pub struct TopicConfig {
    /// The default action to take if no acl is defined for a topic.
    default: MessageAction,
    /// Explicitly configured whitelists for topics. Ip's which aren't part of the whitelist will
    /// not be allowed to send messages to that topic. If a topic is not in this map, the default
    /// action will be used.
    whitelist: HashMap<Vec<u8>, TopicWhitelistConfig>,
}
impl TopicConfig {
    /// Get the [`default action`](MessageAction) if the topic is not configured.
    pub fn default(&self) -> MessageAction {
        self.default
    }
    /// Set the default [`action`](MessageAction) which does not have a whitelist configured.
    pub fn set_default(&mut self, default: MessageAction) {
        self.default = default;
    }
    /// Get the fully configured whitelist
    pub fn whitelist(&self) -> &HashMap<Vec<u8>, TopicWhitelistConfig> {
        &self.whitelist
    }
    /// Insert a new topic in the whitelist, without any configured allowed sources.
    pub fn add_topic_whitelist(&mut self, topic: Vec<u8>) {
        self.whitelist.entry(topic).or_default();
    }
    /// Set the forward socket for a topic. Does nothing if the topic doesn't exist.
    // NOTE(review): `entry(..).and_modify(..)` without an `or_*` call indeed leaves absent
    // topics untouched, while `Node::set_topic_forward_socket` documents that the topic is
    // created if missing — confirm which behavior is intended.
    pub fn set_topic_forward_socket(&mut self, topic: Vec<u8>, socket_path: Option<PathBuf>) {
        self.whitelist
            .entry(topic)
            .and_modify(|c| c.set_forward_socket(socket_path));
    }
    /// Get the forward socket for a topic, if any.
    pub fn get_topic_forward_socket(&self, topic: &Vec<u8>) -> Option<&PathBuf> {
        self.whitelist
            .get(topic)
            .and_then(|config| config.forward_socket())
    }
    /// Remove a topic from the whitelist. Future messages will follow the default action.
    pub fn remove_topic_whitelist(&mut self, topic: &Vec<u8>) {
        self.whitelist.remove(topic);
    }
    /// Adds a new whitelisted source for a topic. This creates the topic if it does not exist yet.
    pub fn add_topic_whitelist_src(&mut self, topic: Vec<u8>, src: Subnet) {
        self.whitelist.entry(topic).or_default().add_subnet(src);
    }
    /// Removes a whitelisted source for a topic.
    ///
    /// If the last source is removed for a topic, the entry remains, and must be cleared by calling
    /// [`Self::remove_topic_whitelist`] to fall back to the default action. Note that an empty
    /// whitelist effectively blocks all messages for a topic.
    ///
    /// This does nothing if the topic does not exist.
    pub fn remove_topic_whitelist_src(&mut self, topic: &Vec<u8>, src: Subnet) {
        if let Some(whitelist_config) = self.whitelist.get_mut(topic) {
            whitelist_config.remove_subnet(&src);
        }
    }
}
/// Action to take for a message when its topic has no explicit whitelist configuration.
#[derive(Debug, Default, Clone, Copy, DeserializeMacro)]
pub enum MessageAction {
    /// Accept the message
    #[default]
    Accept,
    /// Reject the message
    Reject,
}
// Helper function to parse a subnet from a string
fn parse_subnet_str<E>(s: &str) -> Result<Subnet, E>
where
E: serde::de::Error,
{
// Try to parse as a subnet (with prefix)
if let Ok(ipnet) = s.parse::<ipnet::IpNet>() {
return Subnet::new(ipnet.addr(), ipnet.prefix_len())
.map_err(|e| serde::de::Error::custom(format!("Invalid subnet prefix length: {e}")));
}
// Try to parse as an IP address (convert to /32 or /128 subnet)
if let Ok(ip) = s.parse::<std::net::IpAddr>() {
let prefix_len = match ip {
std::net::IpAddr::V4(_) => 32,
std::net::IpAddr::V6(_) => 128,
};
return Subnet::new(ip, prefix_len)
.map_err(|e| serde::de::Error::custom(format!("Invalid subnet prefix length: {e}")));
}
Err(serde::de::Error::custom(format!(
"Invalid subnet or IP address: {s}",
)))
}
// Raw wire/config representation of a whitelist entry, used during deserialization.
#[derive(DeserializeMacro)]
struct WhitelistConfigData {
    /// Subnet (or bare IP address) strings allowed to send to the topic.
    #[serde(default)]
    subnets: Vec<String>,
    /// Optional path to forward messages for this topic to.
    #[serde(default)]
    forward_socket: Option<String>,
}
// Add this implementation right after the TopicConfig struct definition
impl<'de> Deserialize<'de> for TopicConfig {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct TopicConfigVisitor;
impl<'de> Visitor<'de> for TopicConfigVisitor {
type Value = TopicConfig;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a topic configuration")
}
fn visit_map<V>(self, mut map: V) -> Result<TopicConfig, V::Error>
where
V: MapAccess<'de>,
{
let mut default = MessageAction::default();
let mut whitelist = HashMap::new();
while let Some(key) = map.next_key::<String>()? {
if key == "default" {
default = map.next_value()?;
} else {
// Try to parse as a WhitelistConfigData first
if let Ok(config_data) = map.next_value::<WhitelistConfigData>() {
let mut whitelist_config = TopicWhitelistConfig::default();
// Process subnets
for subnet_str in config_data.subnets {
let subnet = parse_subnet_str(&subnet_str)?;
whitelist_config.add_subnet(subnet);
}
// Process forward_socket
if let Some(socket_path) = config_data.forward_socket {
whitelist_config
.set_forward_socket(Some(PathBuf::from(socket_path)));
}
// Convert string key to Vec<u8>
whitelist.insert(key.into_bytes(), whitelist_config);
} else {
// Fallback to old format: just a list of subnets
let subnet_strs = map.next_value::<Vec<String>>()?;
let mut whitelist_config = TopicWhitelistConfig::default();
for subnet_str in subnet_strs {
let subnet = parse_subnet_str(&subnet_str)?;
whitelist_config.add_subnet(subnet);
}
// Convert string key to Vec<u8>
whitelist.insert(key.into_bytes(), whitelist_config);
}
}
}
Ok(TopicConfig { default, whitelist })
}
}
deserializer.deserialize_map(TopicConfigVisitor)
}
}

144
mycelium/src/metric.rs Normal file
View File

@@ -0,0 +1,144 @@
//! Dedicated logic for
//! [metrics](https://datatracker.ietf.org/doc/html/rfc8966#metric-computation).
use core::fmt;
use std::ops::{Add, Sub};
/// Value of the infinite metric.
const METRIC_INFINITE: u16 = 0xFFFF;
/// A `Metric` is used to indicate the cost associated with a route. A lower Metric means a route
/// is more favorable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Ord, PartialOrd)]
pub struct Metric(u16);
impl Metric {
    /// Create a new `Metric` with the given value.
    pub const fn new(value: u16) -> Self {
        Metric(value)
    }
    /// Creates a new infinite `Metric`.
    pub const fn infinite() -> Self {
        Metric(METRIC_INFINITE)
    }
    /// Checks if this metric indicates a retracted route.
    pub const fn is_infinite(&self) -> bool {
        self.0 == METRIC_INFINITE
    }
    /// Checks if this metric represents a directly connected route.
    pub const fn is_direct(&self) -> bool {
        self.0 == 0
    }
    /// Computes the absolute value of the difference between this and another `Metric`.
    pub fn delta(&self, rhs: &Self) -> Metric {
        // abs_diff replaces the manual "branch on which operand is larger" dance.
        Metric(self.0.abs_diff(rhs.0))
    }
}
impl fmt::Display for Metric {
    /// Render the raw metric value, or the literal `Infinite` for the infinite metric.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.is_infinite() {
            true => f.pad("Infinite"),
            false => write!(f, "{}", self.0),
        }
    }
}
impl From<u16> for Metric {
    /// Wrap a raw `u16` metric value (`0xFFFF` is the infinite metric).
    fn from(value: u16) -> Self {
        Metric(value)
    }
}
impl From<Metric> for u16 {
    /// Extract the raw `u16` value of the metric.
    fn from(value: Metric) -> Self {
        value.0
    }
}
impl Add for Metric {
    type Output = Self;
    /// Add two metrics, saturating just below infinity.
    ///
    /// Adding anything to an infinite metric stays infinite. Any other sum is clamped to
    /// `u16::MAX - 1`, the largest finite metric, so finite additions never accidentally
    /// produce the infinite value.
    fn add(self, rhs: Metric) -> Self::Output {
        if self.is_infinite() || rhs.is_infinite() {
            return Metric::infinite();
        }
        // saturating_add caps at u16::MAX (the infinite value); min then clamps that to the
        // largest finite metric. Equivalent to the old checked_add + map construction.
        Metric(self.0.saturating_add(rhs.0).min(u16::MAX - 1))
    }
}
impl Add<&Metric> for &Metric {
    type Output = Metric;
    /// Add two metrics by reference. Delegates to the by-value implementation (`Metric` is
    /// `Copy`), keeping the infinity/saturation semantics in a single place instead of four
    /// duplicated copies.
    fn add(self, rhs: &Metric) -> Self::Output {
        *self + *rhs
    }
}
impl Add<&Metric> for Metric {
    type Output = Self;
    /// Add a metric by reference. Delegates to the by-value implementation (`Metric` is
    /// `Copy`), keeping the infinity/saturation semantics in a single place.
    fn add(self, rhs: &Metric) -> Self::Output {
        self + *rhs
    }
}
impl Add<Metric> for &Metric {
    type Output = Metric;
    /// Add a metric to a referenced metric. Delegates to the by-value implementation
    /// (`Metric` is `Copy`), keeping the infinity/saturation semantics in a single place.
    fn add(self, rhs: Metric) -> Self::Output {
        *self + rhs
    }
}
impl Sub<Metric> for Metric {
    type Output = Metric;
    /// Subtract a finite metric, saturating at 0. Subtracting from infinity stays infinite.
    ///
    /// # Panics
    ///
    /// Panics when `rhs` is the infinite metric.
    fn sub(self, rhs: Metric) -> Self::Output {
        assert!(!rhs.is_infinite(), "Can't subtract an infinite metric");
        if self.is_infinite() {
            Metric::infinite()
        } else {
            Metric(self.0.saturating_sub(rhs.0))
        }
    }
}

195
mycelium/src/metrics.rs Normal file
View File

@@ -0,0 +1,195 @@
//! This module is used for collection of runtime metrics of a `mycelium` system. The main item of
//! interest is the [`Metrics`] trait. Users can provide their own implementation of this, or use
//! the default provided implementation to disable gathering metrics.
use crate::peer_manager::PeerType;
/// The collection of all metrics exported by a [`mycelium node`](crate::Node). It is up to the
/// user to provide an implementation which implements the methods for metrics they are interested
/// in. All methods have a default implementation, so if the user is not interested in any metrics,
/// a NOOP handler can be implemented as follows:
///
/// ```rust
/// use mycelium::metrics::Metrics;
///
/// #[derive(Clone)]
/// struct NoMetrics;
/// impl Metrics for NoMetrics {}
/// ```
pub trait Metrics {
    /// The [`Router`](crate::router::Router) received a new Hello TLV from a peer.
    #[inline]
    fn router_process_hello(&self) {}
    /// The [`Router`](crate::router::Router) received a new IHU TLV from a peer.
    #[inline]
    fn router_process_ihu(&self) {}
    /// The [`Router`](crate::router::Router) received a new Seqno request TLV from a peer.
    #[inline]
    fn router_process_seqno_request(&self) {}
    /// The [`Router`](crate::router::Router) received a new Route request TLV from a peer.
    /// Additionally, it is recorded if this is a wildcard request (route table dump request)
    /// or a request for a specific subnet.
    #[inline]
    fn router_process_route_request(&self, _wildcard: bool) {}
    /// The [`Router`](crate::router::Router) received a new Update TLV from a peer.
    #[inline]
    fn router_process_update(&self) {}
    /// The [`Router`](crate::router::Router) tried to send an update to a peer, but before sending
    /// it we found out the peer is actually already dead.
    ///
    /// This can happen, since a peer is a remote entity we have no control over, and it can be
    /// removed at any time for any reason. However, in normal operation, the amount of times this
    /// happens should be fairly small compared to the amount of updates we send/receive.
    #[inline]
    fn router_update_dead_peer(&self) {}
    /// The amount of TLV's received from peers, to be processed by the
    /// [`Router`](crate::router::Router).
    #[inline]
    fn router_received_tlv(&self) {}
    /// The [`Router`](crate::router::Router) dropped a received TLV before processing it, as the
    /// peer who sent it has already died in the meantime.
    #[inline]
    fn router_tlv_source_died(&self) {}
    /// The [`Router`](crate::router::Router) dropped a received TLV before processing it, because
    /// it couldn't keep up
    #[inline]
    fn router_tlv_discarded(&self) {}
    /// A [`Peer`](crate::peer::Peer) was added to the [`Router`](crate::router::Router).
    #[inline]
    fn router_peer_added(&self) {}
    /// A [`Peer`](crate::peer::Peer) was removed from the [`Router`](crate::router::Router).
    #[inline]
    fn router_peer_removed(&self) {}
    /// A [`Peer`](crate::peer::Peer) informed the [`Router`](crate::router::Router) it died, or
    /// the router otherwise noticed the Peer is dead.
    #[inline]
    fn router_peer_died(&self) {}
    /// The [`Router`](crate::router::Router) ran a route selection procedure.
    #[inline]
    fn router_route_selection_ran(&self) {}
    /// A [`SourceKey`](crate::source_table::SourceKey) expired and got cleaned up by the [`Router`](crate::router::Router).
    #[inline]
    fn router_source_key_expired(&self) {}
    /// A [`RouteKey`](crate::routing_table::RouteKey) expired, and the router either set the
    /// [`Metric`](crate::metric::Metric) of the route to infinity, or cleaned up the route entry
    /// altogether.
    #[inline]
    fn router_route_key_expired(&self, _removed: bool) {}
    /// A route which expired was actually the selected route for the
    /// [`Subnet`](crate::subnet::Subnet). Note that [`Self::router_route_key_expired`] will
    /// also have been called.
    #[inline]
    fn router_selected_route_expired(&self) {}
    /// The [`Router`](crate::router::Router) sends a "triggered" update to its peers.
    #[inline]
    fn router_triggered_update(&self) {}
    /// The [`Router`](crate::router::Router) extracted a packet for the local subnet.
    #[inline]
    fn router_route_packet_local(&self) {}
    /// The [`Router`](crate::router::Router) forwarded a packet to a peer.
    #[inline]
    fn router_route_packet_forward(&self) {}
    /// The [`Router`](crate::router::Router) dropped a packet it was routing because it's TTL
    /// reached 0.
    #[inline]
    fn router_route_packet_ttl_expired(&self) {}
    /// The [`Router`](crate::router::Router) dropped a packet it was routing because there was no
    /// route for the destination IP.
    #[inline]
    fn router_route_packet_no_route(&self) {}
    /// The [`Router`](crate::router::Router) replied to a seqno request with a local route, which
    /// is more recent (bigger seqno) than the request.
    #[inline]
    fn router_seqno_request_reply_local(&self) {}
    /// The [`Router`](crate::router::Router) replied to a seqno request by bumping its own seqno
    /// and advertising the local route.
    #[inline]
    fn router_seqno_request_bump_seqno(&self) {}
    /// The [`Router`](crate::router::Router) dropped a seqno request because the TTL reached 0.
    #[inline]
    fn router_seqno_request_dropped_ttl(&self) {}
    /// The [`Router`](crate::router::Router) forwarded a seqno request to a feasible route.
    #[inline]
    fn router_seqno_request_forward_feasible(&self) {}
    /// The [`Router`](crate::router::Router) forwarded a seqno request to a (potentially)
    /// unfeasible route.
    #[inline]
    fn router_seqno_request_forward_unfeasible(&self) {}
    /// The [`Router`](crate::router::Router) dropped a seqno request because none of the other
    /// handling methods applied.
    #[inline]
    fn router_seqno_request_unhandled(&self) {}
    /// The [`time`](std::time::Duration) used by the [`Router`](crate::router::Router) to handle a
    /// control packet.
    #[inline]
    fn router_time_spent_handling_tlv(&self, _duration: std::time::Duration, _tlv_type: &str) {}
    /// The [`time`](std::time::Duration) used by the [`Router`](crate::router::Router) to
    /// periodically propagate selected routes to peers.
    #[inline]
    fn router_time_spent_periodic_propagating_selected_routes(
        &self,
        _duration: std::time::Duration,
    ) {
    }
    /// An update was processed and accepted by the router, but did not run route selection.
    #[inline]
    fn router_update_skipped_route_selection(&self) {}
    /// An update was denied by a configured filter.
    #[inline]
    fn router_update_denied_by_filter(&self) {}
    /// An update was accepted by the router filters, but was otherwise unfeasible or a retraction,
    /// for an unknown subnet.
    #[inline]
    fn router_update_not_interested(&self) {}
    /// A new [`Peer`](crate::peer::Peer) was added to the
    /// [`PeerManager`](crate::peer_manager::PeerManager) while it is running.
    #[inline]
    fn peer_manager_peer_added(&self, _pt: PeerType) {}
    /// Sets the amount of [`Peers`](crate::peer::Peer) known by the
    /// [`PeerManager`](crate::peer_manager::PeerManager).
    #[inline]
    fn peer_manager_known_peers(&self, _amount: usize) {}
    /// The [`PeerManager`](crate::peer_manager::PeerManager) started an attempt to connect to a
    /// remote endpoint.
    #[inline]
    fn peer_manager_connection_attempted(&self) {}
    /// The [`PeerManager`](crate::peer_manager::PeerManager) finished an attempt to connect to a
    /// remote endpoint. The connection could have failed.
    #[inline]
    fn peer_manager_connection_finished(&self) {}
}

134
mycelium/src/packet.rs Normal file
View File

@@ -0,0 +1,134 @@
use bytes::{Buf, BufMut, BytesMut};
pub use control::ControlPacket;
pub use data::DataPacket;
use tokio_util::codec::{Decoder, Encoder};
mod control;
mod data;
/// Current version of the protocol being used.
const PROTOCOL_VERSION: u8 = 1;
/// The size of a `Packet` header on the wire, in bytes.
const PACKET_HEADER_SIZE: usize = 4;
/// A single frame on the wire: either user data or routing control information.
#[derive(Debug, Clone)]
pub enum Packet {
    DataPacket(DataPacket),
    ControlPacket(ControlPacket),
}
/// Discriminant stored in the second header byte, identifying the packet body type.
#[derive(Debug, Clone, Copy)]
#[repr(u8)]
pub enum PacketType {
    DataPacket = 0,
    ControlPacket = 1,
}
/// Codec framing [`Packet`]s on a byte stream: it reads the common packet header, then
/// delegates body (de)serialization to the data or control sub-codec.
pub struct Codec {
    /// Packet type of a partially decoded frame, set once the header has been consumed.
    packet_type: Option<PacketType>,
    data_packet_codec: data::Codec,
    control_packet_codec: control::Codec,
}
impl Codec {
    /// Create a new `Codec` with no partially decoded frame state.
    pub fn new() -> Self {
        Codec {
            packet_type: None,
            data_packet_codec: data::Codec::new(),
            control_packet_codec: control::Codec::new(),
        }
    }
}
impl Decoder for Codec {
    type Item = Packet;
    type Error = std::io::Error;
    /// Decode a single [`Packet`] from `src`.
    ///
    /// The fixed-size packet header is consumed first and the detected packet type cached, so a
    /// partial frame resumes body decoding on the next call. Body decoding is delegated to the
    /// matching sub-codec.
    ///
    /// # Errors
    ///
    /// Fails with `InvalidData` on an unknown protocol version or packet type, or when the
    /// delegated body decoder fails.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        // Determine the packet_type
        let packet_type = if let Some(packet_type) = self.packet_type {
            packet_type
        } else {
            // Wait until the full header is buffered. Note this must be `<` and not `<=`:
            // exactly PACKET_HEADER_SIZE bytes suffice to read the header.
            if src.remaining() < PACKET_HEADER_SIZE {
                return Ok(None);
            }
            let mut header = [0; PACKET_HEADER_SIZE];
            header.copy_from_slice(&src[..PACKET_HEADER_SIZE]);
            src.advance(PACKET_HEADER_SIZE);
            // For now it's a hard error to not follow the 1 defined protocol version
            if header[0] != PROTOCOL_VERSION {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    "Unknown protocol version",
                ));
            };
            let packet_type = match header[1] {
                0 => PacketType::DataPacket,
                1 => PacketType::ControlPacket,
                _ => {
                    return Err(std::io::Error::new(
                        std::io::ErrorKind::InvalidData,
                        "Invalid packet type",
                    ));
                }
            };
            self.packet_type = Some(packet_type);
            packet_type
        };
        // Decode packet based on determined packet_type. The cached type is only cleared once a
        // full body was decoded, so a partial body resumes here on the next call.
        match packet_type {
            PacketType::DataPacket => {
                let res = self.data_packet_codec.decode(src)?;
                if res.is_some() {
                    // Reset state for the next frame
                    self.packet_type = None;
                }
                Ok(res.map(Packet::DataPacket))
            }
            PacketType::ControlPacket => {
                let res = self.control_packet_codec.decode(src)?;
                if res.is_some() {
                    // Reset state for the next frame
                    self.packet_type = None;
                }
                Ok(res.map(Packet::ControlPacket))
            }
        }
    }
}
impl Encoder<Packet> for Codec {
type Error = std::io::Error;
fn encode(&mut self, item: Packet, dst: &mut BytesMut) -> Result<(), Self::Error> {
match item {
Packet::DataPacket(datapacket) => {
dst.put_slice(&[PROTOCOL_VERSION, 0, 0, 0]);
self.data_packet_codec.encode(datapacket, dst)
}
Packet::ControlPacket(controlpacket) => {
dst.put_slice(&[PROTOCOL_VERSION, 1, 0, 0]);
self.control_packet_codec.encode(controlpacket, dst)
}
}
}
}
impl Default for Codec {
    /// Equivalent to [`Codec::new`].
    fn default() -> Self {
        Self::new()
    }
}

View File

@@ -0,0 +1,64 @@
use std::{io, net::IpAddr, time::Duration};
use bytes::BytesMut;
use tokio_util::codec::{Decoder, Encoder};
use crate::{
babel, metric::Metric, peer::Peer, router_id::RouterId, sequence_number::SeqNo, subnet::Subnet,
};
/// Control packets are babel TLVs.
pub type ControlPacket = babel::Tlv;
/// Codec to (de)serialize [`ControlPacket`]s; a thin wrapper around the babel codec.
pub struct Codec {
    // TODO: wrapper to make it easier to deserialize
    codec: babel::Codec,
}
impl ControlPacket {
    /// Build a unicast Hello TLV for `dest_peer` carrying the peer's current hello seqno, and
    /// increment that seqno afterwards. The interval is encoded in units of 10ms.
    pub fn new_hello(dest_peer: &Peer, interval: Duration) -> Self {
        let tlv: babel::Tlv =
            babel::Hello::new_unicast(dest_peer.hello_seqno(), (interval.as_millis() / 10) as u16)
                .into();
        // The seqno must only be bumped after it was captured in the TLV above.
        dest_peer.increment_hello_seqno();
        tlv
    }
    /// Build an IHU TLV with the given receive cost, interval (encoded in units of 10ms) and
    /// optional destination address.
    pub fn new_ihu(rx_cost: Metric, interval: Duration, dest_address: Option<IpAddr>) -> Self {
        babel::Ihu::new(rx_cost, (interval.as_millis() / 10) as u16, dest_address).into()
    }
    /// Build an Update TLV announcing `subnet` with the given seqno, metric and router id.
    pub fn new_update(
        interval: Duration,
        seqno: SeqNo,
        metric: Metric,
        subnet: Subnet,
        router_id: RouterId,
    ) -> Self {
        babel::Update::new(interval, seqno, metric, subnet, router_id).into()
    }
}
impl Codec {
    /// Create a new `Codec` wrapping a fresh babel codec.
    pub fn new() -> Self {
        Codec {
            codec: babel::Codec::new(),
        }
    }
}
impl Decoder for Codec {
    type Item = ControlPacket;
    type Error = std::io::Error;
    /// Decoding is delegated entirely to the wrapped babel codec.
    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        self.codec.decode(buf)
    }
}
impl Encoder<ControlPacket> for Codec {
    type Error = io::Error;
    /// Encoding is delegated entirely to the wrapped babel codec.
    fn encode(&mut self, message: ControlPacket, buf: &mut BytesMut) -> Result<(), Self::Error> {
        self.codec.encode(message, buf)
    }
}

154
mycelium/src/packet/data.rs Normal file
View File

@@ -0,0 +1,154 @@
use std::net::Ipv6Addr;
use bytes::{Buf, BufMut, BytesMut};
use tokio_util::codec::{Decoder, Encoder};
/// Size of the header start for a data packet (before the IP addresses).
const DATA_PACKET_HEADER_SIZE: usize = 4;
/// Mask to extract the 16-bit data length from the raw packet header.
const DATA_PACKET_LEN_MASK: u32 = (1 << 16) - 1;
/// A user data packet as exchanged between peers.
#[derive(Debug, Clone)]
pub struct DataPacket {
    pub raw_data: Vec<u8>, // encrypted data itself, then append the nonce
    /// Max amount of hops for the packet.
    pub hop_limit: u8,
    /// Source address of the packet.
    pub src_ip: Ipv6Addr,
    /// Destination address of the packet.
    pub dst_ip: Ipv6Addr,
}
/// Stateful codec to (de)serialize [`DataPacket`]s, able to resume decoding of partial frames.
pub struct Codec {
    /// Cached header values of a partially decoded packet.
    header_vals: Option<HeaderValues>,
    /// Cached source address of a partially decoded packet.
    src_ip: Option<Ipv6Addr>,
    /// Cached destination address of a partially decoded packet.
    dest_ip: Option<Ipv6Addr>,
}
/// Data from the DataPacket header.
#[derive(Clone, Copy)]
struct HeaderValues {
    /// Length of the packet payload in bytes.
    len: u16,
    /// Remaining hop limit of the packet.
    hop_limit: u8,
}
impl Codec {
    /// Create a new `Codec` with no partially decoded frame state.
    pub fn new() -> Self {
        Codec {
            header_vals: None,
            src_ip: None,
            dest_ip: None,
        }
    }
}
impl Decoder for Codec {
type Item = DataPacket;
type Error = std::io::Error;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
// Determine the length of the data
let HeaderValues { len, hop_limit } = if let Some(header_vals) = self.header_vals {
header_vals
} else {
// Check we have enough data to decode
if src.len() < DATA_PACKET_HEADER_SIZE {
return Ok(None);
}
let raw_header = src.get_u32();
// Hop limit is the last 8 bits.
let hop_limit = (raw_header & 0xFF) as u8;
let data_len = ((raw_header >> 8) & DATA_PACKET_LEN_MASK) as u16;
let header_vals = HeaderValues {
len: data_len,
hop_limit,
};
self.header_vals = Some(header_vals);
header_vals
};
let data_len = len as usize;
// Determine the source IP
let src_ip = if let Some(src_ip) = self.src_ip {
src_ip
} else {
if src.len() < 16 {
return Ok(None);
}
// Decode octets
let mut ip_bytes = [0u8; 16];
ip_bytes.copy_from_slice(&src[..16]);
let src_ip = Ipv6Addr::from(ip_bytes);
src.advance(16);
self.src_ip = Some(src_ip);
src_ip
};
// Determine the destination IP
let dest_ip = if let Some(dest_ip) = self.dest_ip {
dest_ip
} else {
if src.len() < 16 {
return Ok(None);
}
// Decode octets
let mut ip_bytes = [0u8; 16];
ip_bytes.copy_from_slice(&src[..16]);
let dest_ip = Ipv6Addr::from(ip_bytes);
src.advance(16);
self.dest_ip = Some(dest_ip);
dest_ip
};
// Check we have enough data to decode
if src.len() < data_len {
return Ok(None);
}
// Decode octets
let mut data = vec![0u8; data_len];
data.copy_from_slice(&src[..data_len]);
src.advance(data_len);
// Reset state
self.header_vals = None;
self.dest_ip = None;
self.src_ip = None;
Ok(Some(DataPacket {
raw_data: data,
hop_limit,
dst_ip: dest_ip,
src_ip,
}))
}
}
impl Encoder<DataPacket> for Codec {
type Error = std::io::Error;
fn encode(&mut self, item: DataPacket, dst: &mut BytesMut) -> Result<(), Self::Error> {
dst.reserve(item.raw_data.len() + DATA_PACKET_HEADER_SIZE + 16 + 16);
let mut raw_header = 0;
// Add length of the data
raw_header |= (item.raw_data.len() as u32) << 8;
// And hop limit
raw_header |= item.hop_limit as u32;
dst.put_u32(raw_header);
// Write the source IP
dst.put_slice(&item.src_ip.octets());
// Write the destination IP
dst.put_slice(&item.dst_ip.octets());
// Write the data
dst.extend_from_slice(&item.raw_data);
Ok(())
}
}

401
mycelium/src/peer.rs Normal file
View File

@@ -0,0 +1,401 @@
use futures::{SinkExt, StreamExt};
use std::{
error::Error,
io,
sync::{
atomic::{AtomicBool, AtomicU64, Ordering},
Arc, RwLock, Weak,
},
};
use tokio::{
select,
sync::{mpsc, Notify},
};
use tokio_util::codec::Framed;
use tracing::{debug, error, info, trace};
use crate::{
connection::{self, Connection},
packet::{self, Packet},
};
use crate::{
packet::{ControlPacket, DataPacket},
sequence_number::SeqNo,
};
/// The maximum amount of packets to immediately send if they are ready when the first one is
/// received.
const PACKET_COALESCE_WINDOW: usize = 50;
/// The default link cost assigned to new peers before their actual cost is known.
///
/// In theory, the best value would be u16::MAX - 1, however this value would take too long to be
/// flushed out of the smoothed metric. A default of a 1000 (1 second) should be sufficiently large
/// to cover very bad connections, so they also converge to a smaller value. While there is no
/// issue with converging to a higher value (in other words, underestimating the latency to a
/// peer), this means that bad peers would briefly be more likely to be selected. Additionally,
/// since the latency increases, downstream peers would eventually find that the announced route
/// would become unfeasible, and send a seqno request (which should solve this efficiently). As a
/// tradeoff, it means it takes longer for new peers in the network to decrease to their actual
/// metric (in comparison with a lower starting metric), though this is in itself a useful thing
/// to have as it means peers joining the network would need to have some stability before being
/// selected as hop.
const DEFAULT_LINK_COST: u16 = 1000;
/// Multiplier for smoothed metric calculation of the existing smoothed metric.
const EXISTING_METRIC_FACTOR: u32 = 9;
/// Divisor for smoothed metric calculation of the combined metric
const TOTAL_METRIC_DIVISOR: u32 = 10;
#[derive(Debug, Clone)]
/// A peer represents a directly connected participant in the network.
pub struct Peer {
    /// Shared state; cloning a `Peer` clones the handle, not the underlying peer.
    inner: Arc<PeerInner>,
}
/// A weak reference to a peer, which does not prevent it from being cleaned up. This can be used
/// to check liveliness of the [`Peer`] instance it originated from.
pub struct PeerRef {
    /// Weak handle to the shared peer state.
    inner: Weak<PeerInner>,
}
impl Peer {
    /// Create a new `Peer` driving the given `connection`.
    ///
    /// This spawns a background task which owns the connection: frames read from it are
    /// forwarded to the router channels, packets queued via [`Peer::send_data_packet`] /
    /// [`Peer::send_control_packet`] are written out (coalescing up to
    /// `PACKET_COALESCE_WINDOW` ready packets before flushing), and once the task exits the
    /// peer is reported on `dead_peer_sink`.
    pub fn new<C: Connection + Unpin + Send + 'static>(
        router_data_tx: mpsc::Sender<DataPacket>,
        router_control_tx: mpsc::UnboundedSender<(ControlPacket, Peer)>,
        connection: C,
        dead_peer_sink: mpsc::Sender<Peer>,
        bytes_written: Arc<AtomicU64>,
        bytes_read: Arc<AtomicU64>,
    ) -> Result<Self, io::Error> {
        // Wrap connection so we can get access to the counters.
        let connection = connection::Tracked::new(bytes_read, bytes_written, connection);
        // Data channel for peer
        let (to_peer_data, mut from_routing_data) = mpsc::unbounded_channel::<DataPacket>();
        // Control channel for peer
        let (to_peer_control, mut from_routing_control) =
            mpsc::unbounded_channel::<ControlPacket>();
        let death_notifier = Arc::new(Notify::new());
        let death_watcher = death_notifier.clone();
        let peer = Peer {
            inner: Arc::new(PeerInner {
                state: RwLock::new(PeerState::new()),
                to_peer_data,
                to_peer_control,
                connection_identifier: connection.identifier()?,
                static_link_cost: connection.static_link_cost()?,
                death_notifier,
                alive: AtomicBool::new(true),
            }),
        };
        // Framed for peer
        // Used to send and receive packets from a TCP stream
        let framed = Framed::with_capacity(connection, packet::Codec::new(), 128 << 10);
        let (mut sink, mut stream) = framed.split();
        {
            let peer = peer.clone();
            tokio::spawn(async move {
                // Whether buffered packets are waiting for a flush of the sink.
                let mut needs_flush = false;
                loop {
                    select! {
                        // Received over the TCP stream
                        frame = stream.next() => {
                            match frame {
                                Some(Ok(packet)) => {
                                    match packet {
                                        Packet::DataPacket(packet) => {
                                            // An error here means the receiver is dropped/closed,
                                            // this is not recoverable.
                                            if let Err(error) = router_data_tx.send(packet).await{
                                                error!("Error sending to to_routing_data: {}", error);
                                                break
                                            }
                                        }
                                        Packet::ControlPacket(packet) => {
                                            if let Err(error) = router_control_tx.send((packet, peer.clone())) {
                                                // An error here means the receiver is dropped/closed,
                                                // this is not recoverable.
                                                error!("Error sending to to_routing_control: {}", error);
                                                break
                                            }
                                        }
                                    }
                                }
                                Some(Err(e)) => {
                                    error!("Frame error from {}: {e}", peer.connection_identifier());
                                    break;
                                },
                                None => {
                                    info!("Stream to {} is closed", peer.connection_identifier());
                                    break;
                                }
                            }
                        }
                        // New outbound data, only polled while nothing is waiting to be flushed.
                        rv = from_routing_data.recv(), if !needs_flush => {
                            match rv {
                                None => break,
                                Some(packet) => {
                                    needs_flush = true;
                                    if let Err(e) = sink.feed(Packet::DataPacket(packet)).await {
                                        error!("Failed to feed data packet to connection: {e}");
                                        break
                                    }
                                    for _ in 1..PACKET_COALESCE_WINDOW {
                                        // There can be 2 cases of errors here, empty channel and no more
                                        // senders. In both cases we don't really care at this point.
                                        if let Ok(packet) = from_routing_data.try_recv() {
                                            if let Err(e) = sink.feed(Packet::DataPacket(packet)).await {
                                                error!("Failed to feed data packet to connection: {e}");
                                                break
                                            }
                                            trace!("Instantly queued ready packet to transfer to peer");
                                        } else {
                                            // No packets ready, flush currently buffered ones
                                            break
                                        }
                                    }
                                }
                            }
                        }
                        // New outbound control info, only polled while nothing is waiting to be flushed.
                        rv = from_routing_control.recv(), if !needs_flush => {
                            match rv {
                                None => break,
                                Some(packet) => {
                                    needs_flush = true;
                                    if let Err(e) = sink.feed(Packet::ControlPacket(packet)).await {
                                        error!("Failed to feed control packet to connection: {e}");
                                        break
                                    }
                                    for _ in 1..PACKET_COALESCE_WINDOW {
                                        // There can be 2 cases of errors here, empty channel and no more
                                        // senders. In both cases we don't really care at this point.
                                        if let Ok(packet) = from_routing_control.try_recv() {
                                            if let Err(e) = sink.feed(Packet::ControlPacket(packet)).await {
                                                error!("Failed to feed data packet to connection: {e}");
                                                break
                                            }
                                        } else {
                                            // No packets ready, flush currently buffered ones
                                            break
                                        }
                                    }
                                }
                            }
                        }
                        r = sink.flush(), if needs_flush => {
                            if let Err(err) = r {
                                error!("Failed to flush peer connection: {err}");
                                break
                            }
                            needs_flush = false;
                        }
                        _ = death_watcher.notified() => {
                            // Attempt graceful shutdown
                            let mut framed = sink.reunite(stream).expect("SplitSink and SplitStream here can only be part of the same original Framned; Qed");
                            let _ = framed.close().await;
                            break;
                        }
                    }
                }
                // Notify router we are dead, also modify our internal state to declare that.
                // Relaxed ordering is fine, we just care that the variable is set.
                peer.inner.alive.store(false, Ordering::Relaxed);
                let remote_id = peer.connection_identifier().clone();
                debug!("Notifying router peer {remote_id} is dead");
                if let Err(e) = dead_peer_sink.send(peer).await {
                    error!("Peer {remote_id} could not notify router of termination: {e}");
                }
            });
        };
        Ok(peer)
    }
    /// Get current sequence number for this peer.
    pub fn hello_seqno(&self) -> SeqNo {
        self.inner.state.read().unwrap().hello_seqno
    }
    /// Adds 1 to the sequence number of this peer.
    pub fn increment_hello_seqno(&self) {
        self.inner.state.write().unwrap().hello_seqno += 1;
    }
    /// Moment the last hello was received from this peer.
    pub fn time_last_received_hello(&self) -> tokio::time::Instant {
        self.inner.state.read().unwrap().time_last_received_hello
    }
    /// Record the moment the last hello was received from this peer.
    pub fn set_time_last_received_hello(&self, time: tokio::time::Instant) {
        self.inner.state.write().unwrap().time_last_received_hello = time
    }
    /// For sending data packets towards a peer instance on this node.
    /// It's send over the to_peer_data channel and read from the corresponding receiver.
    /// The receiver sends the packet over the TCP stream towards the destined peer instance on another node
    pub fn send_data_packet(&self, data_packet: DataPacket) -> Result<(), Box<dyn Error>> {
        Ok(self.inner.to_peer_data.send(data_packet)?)
    }
    /// For sending control packets towards a peer instance on this node.
    /// It's send over the to_peer_control channel and read from the corresponding receiver.
    /// The receiver sends the packet over the TCP stream towards the destined peer instance on another node
    pub fn send_control_packet(&self, control_packet: ControlPacket) -> Result<(), Box<dyn Error>> {
        Ok(self.inner.to_peer_control.send(control_packet)?)
    }
    /// Get the cost to use the peer, i.e. the additional impact on the [`crate::metric::Metric`]
    /// for using this `Peer`.
    ///
    /// This is a smoothed value, which is calculated over the recent history of link cost.
    pub fn link_cost(&self) -> u16 {
        self.inner.state.read().unwrap().link_cost + self.inner.static_link_cost
    }
    /// Sets the link cost based on the provided value.
    ///
    /// The link cost is not set to the given value, but rather to an average of recent values.
    /// This makes sure short-lived, hard spikes of the link cost of a peer don't influence the
    /// routing.
    pub fn set_link_cost(&self, new_link_cost: u16) {
        // Calculate new link cost by multiplying (i.e. scaling) old and new link cost and
        // averaging them.
        let mut inner = self.inner.state.write().unwrap();
        inner.link_cost = (((inner.link_cost as u32) * EXISTING_METRIC_FACTOR
            + (new_link_cost as u32) * (TOTAL_METRIC_DIVISOR - EXISTING_METRIC_FACTOR))
            / TOTAL_METRIC_DIVISOR) as u16;
    }
    /// Identifier for the connection to the `Peer`.
    pub fn connection_identifier(&self) -> &String {
        &self.inner.connection_identifier
    }
    /// Moment the last IHU was received from this peer.
    pub fn time_last_received_ihu(&self) -> tokio::time::Instant {
        self.inner.state.read().unwrap().time_last_received_ihu
    }
    /// Record the moment the last IHU was received from this peer.
    pub fn set_time_last_received_ihu(&self, time: tokio::time::Instant) {
        self.inner.state.write().unwrap().time_last_received_ihu = time
    }
    /// Notify this `Peer` that it died.
    ///
    /// While some [`Connection`] types can immediately detect that the connection itself is
    /// broken, not all of them can. In this scenario, we need to rely on an outside signal to tell
    /// us that we have, in fact, died.
    pub fn died(&self) {
        self.inner.alive.store(false, Ordering::Relaxed);
        self.inner.death_notifier.notify_one();
    }
    /// Checks if the connection of this `Peer` is still alive.
    ///
    /// For connection types which don't have (real time) state information, this might return a
    /// false positive if the connection has actually died, but the Peer did not notice this (yet)
    /// and hasn't been informed.
    pub fn alive(&self) -> bool {
        self.inner.alive.load(Ordering::Relaxed)
    }
    /// Create a new [`PeerRef`] that refers to this `Peer` instance.
    pub fn refer(&self) -> PeerRef {
        PeerRef {
            inner: Arc::downgrade(&self.inner),
        }
    }
}
impl PeerRef {
    /// Constructs a new `PeerRef` which is not associated with any actual [`Peer`].
    /// [`PeerRef::alive`] will always return false when called on this `PeerRef`.
    pub fn new() -> Self {
        PeerRef { inner: Weak::new() }
    }
    /// Check if the connection of the [`Peer`] this `PeerRef` points to is still alive.
    pub fn alive(&self) -> bool {
        match self.inner.upgrade() {
            Some(peer) => peer.alive.load(Ordering::Relaxed),
            None => false,
        }
    }
    /// Attempts to convert this `PeerRef` into a full [`Peer`].
    pub fn upgrade(&self) -> Option<Peer> {
        let inner = self.inner.upgrade()?;
        Some(Peer { inner })
    }
}
impl Default for PeerRef {
    /// Equivalent to [`PeerRef::new`]: a dangling reference not tied to any [`Peer`].
    fn default() -> Self {
        Self::new()
    }
}

// Equality of peers is identity of the underlying shared state: two `Peer` handles are equal
// only if they point to the same `PeerInner` allocation (i.e. the same connection), regardless
// of the field values.
impl PartialEq for Peer {
    fn eq(&self, other: &Self) -> bool {
        Arc::ptr_eq(&self.inner, &other.inner)
    }
}
/// Shared state backing a [`Peer`] handle. All `Peer` clones point to the same `PeerInner`.
#[derive(Debug)]
struct PeerInner {
    /// Mutable, lock-protected parts of the peer state.
    state: RwLock<PeerState>,
    /// Channel used to forward data packets to the connection handler of this peer.
    to_peer_data: mpsc::UnboundedSender<DataPacket>,
    /// Channel used to forward control packets to the connection handler of this peer.
    to_peer_control: mpsc::UnboundedSender<ControlPacket>,
    /// Used to identify peer based on its connection params.
    connection_identifier: String,
    /// Static cost of using this link, to be added to the announced metric for routes through this
    /// Peer.
    static_link_cost: u16,
    /// Channel to notify the connection of its decease.
    death_notifier: Arc<Notify>,
    /// Keep track if the connection is alive.
    alive: AtomicBool,
}

/// Mutable state of a [`Peer`], guarded by the `RwLock` in [`PeerInner`].
#[derive(Debug)]
struct PeerState {
    /// Sequence number of the last hello sent to this peer.
    hello_seqno: SeqNo,
    /// Moment the last hello was received from this peer.
    time_last_received_hello: tokio::time::Instant,
    /// Smoothed cost of the link to this peer (see `Peer::set_link_cost`).
    link_cost: u16,
    /// Moment the last IHU was received from this peer.
    time_last_received_ihu: tokio::time::Instant,
}
impl PeerState {
    /// Create a new `PeerState`, holding the mutable state of a [`Peer`].
    fn new() -> Self {
        // Initialize last_sent_hello_seqno to 0
        let hello_seqno = SeqNo::default();
        let link_cost = DEFAULT_LINK_COST;
        // Initialize time_last_received_hello to now
        let time_last_received_hello = tokio::time::Instant::now();
        // Initialize time_last_received_ihu to now as well
        let time_last_received_ihu = tokio::time::Instant::now();

        Self {
            hello_seqno,
            link_cost,
            time_last_received_ihu,
            time_last_received_hello,
        }
    }
}

1361
mycelium/src/peer_manager.rs Normal file

File diff suppressed because it is too large Load Diff

2216
mycelium/src/router.rs Normal file

File diff suppressed because it is too large Load Diff

60
mycelium/src/router_id.rs Normal file
View File

@@ -0,0 +1,60 @@
use core::fmt;
use crate::crypto::PublicKey;
/// A `RouterId` uniquely identifies a router in the network.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct RouterId {
    /// Public key of the router.
    pk: PublicKey,
    /// Zone identifier; currently always zeroed out by [`RouterId::new`].
    zone: [u8; 2],
    /// Random tail, so multiple `RouterId`s built from the same key are distinct.
    rnd: [u8; 6],
}

impl RouterId {
    /// Size in bytes of a `RouterId`
    pub const BYTE_SIZE: usize = 40;

    /// Create a new `RouterId` from a [`PublicKey`].
    ///
    /// The zone is zeroed and the 6 trailing bytes are filled with random data.
    pub fn new(pk: PublicKey) -> Self {
        Self {
            pk,
            zone: [0; 2],
            rnd: rand::random(),
        }
    }

    /// View this `RouterId` as a byte array.
    ///
    /// Layout: bytes 0..32 are the public key, 32..34 the zone, 34..40 the random tail.
    pub fn as_bytes(&self) -> [u8; Self::BYTE_SIZE] {
        let mut out = [0; Self::BYTE_SIZE];
        out[..32].copy_from_slice(self.pk.as_bytes());
        out[32..34].copy_from_slice(&self.zone);
        out[34..].copy_from_slice(&self.rnd);
        out
    }

    /// Converts this `RouterId` to a [`PublicKey`].
    pub fn to_pubkey(self) -> PublicKey {
        self.pk
    }
}
impl From<[u8; Self::BYTE_SIZE]> for RouterId {
    /// Inverse of [`RouterId::as_bytes`]: split the fixed-size array back into its fields.
    fn from(bytes: [u8; Self::BYTE_SIZE]) -> RouterId {
        RouterId {
            // The try_into calls can't fail: the source ranges have the exact lengths of the
            // target arrays (32, 2 and 6 bytes respectively), hence the unwraps.
            pk: PublicKey::from(<&[u8] as TryInto<[u8; 32]>>::try_into(&bytes[..32]).unwrap()),
            zone: bytes[32..34].try_into().unwrap(),
            rnd: bytes[34..Self::BYTE_SIZE].try_into().unwrap(),
        }
    }
}
impl fmt::Display for RouterId {
    /// Renders the id as `<public key>-<zone hex>-<random hex>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "{}-{}-{}",
            self.pk,
            faster_hex::hex_string(&self.zone),
            faster_hex::hex_string(&self.rnd)
        )
    }
}

View File

@@ -0,0 +1,687 @@
use std::{
net::{IpAddr, Ipv6Addr},
ops::Deref,
sync::{Arc, Mutex, MutexGuard},
};
use ip_network_table_deps_treebitmap::IpLookupTable;
use iter::{RoutingTableNoRouteIter, RoutingTableQueryIter};
use subnet_entry::SubnetEntry;
use tokio::{select, sync::mpsc, time::Duration};
use tokio_util::sync::CancellationToken;
use tracing::{error, trace};
use crate::{crypto::SharedSecret, peer::Peer, subnet::Subnet};
pub use iter::RoutingTableIter;
pub use iter_mut::RoutingTableIterMut;
pub use no_route::NoRouteSubnet;
pub use queried_subnet::QueriedSubnet;
pub use route_entry::RouteEntry;
pub use route_key::RouteKey;
pub use route_list::RouteList;
mod iter;
mod iter_mut;
mod no_route;
mod queried_subnet;
mod route_entry;
mod route_key;
mod route_list;
mod subnet_entry;
const NO_ROUTE_EXPIRATION: Duration = Duration::from_secs(60);
/// Outcome of a route lookup in the [`RoutingTable`].
pub enum Routes {
    /// A route list is present for the subnet.
    Exist(RouteListReadGuard),
    /// The subnet is currently being queried; no routes yet.
    Queried,
    /// The subnet is explicitly marked as having no route.
    NoRoute,
    /// The routing table has no entry at all for the subnet.
    None,
}
impl Routes {
    /// Returns the selected route if one exists.
    pub fn selected(&self) -> Option<&RouteEntry> {
        match self {
            Routes::Exist(routes) => routes.selected(),
            _ => None,
        }
    }

    /// Returns true if there are no routes
    pub fn is_none(&self) -> bool {
        // Only the Exist variant actually carries routes; Queried, NoRoute and None do not.
        match self {
            Routes::Exist(_) => false,
            _ => true,
        }
    }
}
impl From<&SubnetEntry> for Routes {
    /// Maps each table entry state onto the corresponding [`Routes`] variant, taking an
    /// arc-swap snapshot of the route list for existing entries.
    fn from(value: &SubnetEntry) -> Self {
        match value {
            SubnetEntry::Exists { list } => {
                Routes::Exist(RouteListReadGuard { inner: list.load() })
            }
            SubnetEntry::Queried { .. } => Routes::Queried,
            SubnetEntry::NoRoute { .. } => Routes::NoRoute,
        }
    }
}
impl From<Option<&SubnetEntry>> for Routes {
    /// A missing table entry maps to [`Routes::None`]; otherwise defer to the
    /// `From<&SubnetEntry>` conversion.
    fn from(value: Option<&SubnetEntry>) -> Self {
        value.map_or(Routes::None, Routes::from)
    }
}
/// The routing table holds a list of route entries for every known subnet.
#[derive(Clone)]
pub struct RoutingTable {
    /// Write half of the left-right table; shared behind a mutex so every clone of the
    /// routing table funnels its mutations through the same handle.
    writer: Arc<Mutex<left_right::WriteHandle<RoutingTableInner, RoutingTableOplogEntry>>>,
    /// Lock-free read half of the left-right table.
    reader: left_right::ReadHandle<RoutingTableInner>,
    /// State shared between all clones; dropping the last clone cancels background tasks.
    shared: Arc<RoutingTableShared>,
}

/// State shared by all clones of a [`RoutingTable`].
struct RoutingTableShared {
    /// Sink used to notify an external observer about expired route entries.
    expired_route_entry_sink: mpsc::Sender<RouteKey>,
    /// Cancels background expiration tasks when the table is dropped.
    cancel_token: CancellationToken,
}

/// The actual map backing the routing table, replicated by left-right.
#[derive(Default)]
struct RoutingTableInner {
    table: IpLookupTable<Ipv6Addr, Arc<SubnetEntry>>,
}
/// Hold an exclusive write lock over the routing table. While this item is in scope, no other
/// calls can get a mutable reference to the content of a routing table. Once this guard goes out of
/// scope, changes to the contained RouteList will be applied.
pub struct WriteGuard<'a> {
    routing_table: &'a RoutingTable,
    /// Shared handle to the subnet entry being written; mutations go through the arc-swap
    /// inside the entry, not through this pointer directly.
    value: Arc<SubnetEntry>,
    /// Did the RouteList exist initially? Used by Drop to decide between insert and delete.
    exists: bool,
    /// The subnet we are writing to.
    subnet: Subnet,
    /// Sink handed to update callbacks so they can report expired route entries.
    expired_route_entry_sink: mpsc::Sender<RouteKey>,
    /// Token handed to update callbacks so spawned work can be cancelled with the table.
    cancellation_token: CancellationToken,
}
impl RoutingTable {
    /// Create a new empty RoutingTable. The passed channel is used to notify an external observer
    /// of route entry expiration events. It is the callers responsibility to ensure these events
    /// are properly handled.
    ///
    /// # Panics
    ///
    /// This will panic if not executed in the context of a tokio runtime.
    pub fn new(expired_route_entry_sink: mpsc::Sender<RouteKey>) -> Self {
        let (writer, reader) = left_right::new();
        let writer = Arc::new(Mutex::new(writer));
        let cancel_token = CancellationToken::new();
        let shared = Arc::new(RoutingTableShared {
            expired_route_entry_sink,
            cancel_token,
        });
        RoutingTable {
            writer,
            reader,
            shared,
        }
    }

    /// Get a list of the routes for the most precise [`Subnet`] known which contains the given
    /// [`IpAddr`].
    ///
    /// # Panics
    ///
    /// Panics if the given address is not an IPv6 address.
    pub fn best_routes(&self, ip: IpAddr) -> Routes {
        let IpAddr::V6(ip) = ip else {
            panic!("Only IPv6 is supported currently");
        };
        self.reader
            .enter()
            .expect("Write handle is saved on the router so it is not dropped yet.")
            .table
            .longest_match(ip)
            .map(|(_, _, rl)| rl.as_ref())
            .into()
    }

    /// Get a list of all routes for the given subnet. Changes to the RoutingTable after this
    /// method returns will not be visible and require this method to be called again to be
    /// observed.
    pub fn routes(&self, subnet: Subnet) -> Routes {
        let subnet_ip = if let IpAddr::V6(ip) = subnet.address() {
            ip
        } else {
            // IPv4 subnets are never stored, so there is trivially no entry.
            return Routes::None;
        };
        self.reader
            .enter()
            .expect("Write handle is saved on the router so it is not dropped yet.")
            .table
            .exact_match(subnet_ip, subnet.prefix_len().into())
            .map(Arc::as_ref)
            .into()
    }

    /// Gets continued read access to the `RoutingTable`. While the returned
    /// [`guard`](RoutingTableReadGuard) is held, updates to the `RoutingTable` will be blocked.
    pub fn read(&self) -> RoutingTableReadGuard {
        RoutingTableReadGuard {
            guard: self
                .reader
                .enter()
                .expect("Write handle is saved on RoutingTable, so this is always Some; qed"),
        }
    }

    /// Locks the `RoutingTable` for continued write access. While the returned
    /// [`guard`](RoutingTableWriteGuard) is held, methods trying to mutate the `RoutingTable`, or
    /// get mutable access otherwise, will be blocked. When the [`guard`](`RoutingTableWriteGuard`)
    /// is dropped, all queued changes will be applied.
    pub fn write(&self) -> RoutingTableWriteGuard {
        RoutingTableWriteGuard {
            write_guard: self.writer.lock().unwrap(),
            read_guard: self
                .reader
                .enter()
                .expect("Write handle is saved on RoutingTable, so this is always Some; qed"),
            expired_route_entry_sink: self.shared.expired_route_entry_sink.clone(),
            cancel_token: self.shared.cancel_token.clone(),
        }
    }

    /// Get mutable access to the list of routes for the given [`Subnet`].
    ///
    /// Returns `None` if the subnet has no entry, or if its entry is in the `Queried` or
    /// `NoRoute` state (i.e. there is no actual route list to mutate).
    ///
    /// # Panics
    ///
    /// Panics if the subnet is an IPv4 subnet.
    pub fn routes_mut(&self, subnet: Subnet) -> Option<WriteGuard> {
        let subnet_address = if let IpAddr::V6(ip) = subnet.address() {
            ip
        } else {
            panic!("IP v4 addresses are not supported")
        };
        let value = self
            .reader
            .enter()
            .expect("Write handle is saved next to read handle so this is always Some; qed")
            .table
            .exact_match(subnet_address, subnet.prefix_len().into())?
            .clone();
        if matches!(*value, SubnetEntry::Exists { .. }) {
            Some(WriteGuard {
                routing_table: self,
                // The entry exists; mutations go through the arc-swap inside it, so a clone of
                // the Arc is all the guard needs.
                value,
                exists: true,
                subnet,
                expired_route_entry_sink: self.shared.expired_route_entry_sink.clone(),
                cancellation_token: self.shared.cancel_token.clone(),
            })
        } else {
            None
        }
    }

    /// Adds a new [`Subnet`] to the `RoutingTable`. The returned [`WriteGuard`] can be used to
    /// insert entries. If no entry is inserted before the guard is dropped, the [`Subnet`] won't
    /// be added.
    ///
    /// # Panics
    ///
    /// Panics if the subnet is an IPv4 subnet.
    pub fn add_subnet(&self, subnet: Subnet, shared_secret: SharedSecret) -> WriteGuard {
        if !matches!(subnet.address(), IpAddr::V6(_)) {
            panic!("IP v4 addresses are not supported")
        };
        let value = Arc::new(SubnetEntry::Exists {
            list: Arc::new(RouteList::new(shared_secret)).into(),
        });
        WriteGuard {
            routing_table: self,
            value,
            // Not yet in the table; Drop will insert it only if a route was actually added.
            exists: false,
            subnet,
            expired_route_entry_sink: self.shared.expired_route_entry_sink.clone(),
            cancellation_token: self.shared.cancel_token.clone(),
        }
    }

    /// Gets the selected route for an IpAddr if one exists.
    ///
    /// # Panics
    ///
    /// This will panic if the IP address is not an IPV6 address.
    pub fn selected_route(&self, address: IpAddr) -> Option<RouteEntry> {
        let IpAddr::V6(ip) = address else {
            panic!("IP v4 addresses are not supported")
        };
        self.reader
            .enter()
            .expect("Write handle is saved on RoutingTable, so this is always Some; qed")
            .table
            .longest_match(ip)
            .and_then(|(_, _, rl)| {
                let SubnetEntry::Exists { list } = &**rl else {
                    return None;
                };
                // By convention the selected route, if any, lives at index 0 of the list.
                let rl = list.load();
                if rl.is_empty() || !rl[0].selected() {
                    None
                } else {
                    Some(rl[0].clone())
                }
            })
    }

    /// Marks a subnet as queried in the route table.
    ///
    /// This function will not do anything if the subnet contains valid routes.
    ///
    /// # Panics
    ///
    /// Panics if the subnet is an IPv4 subnet.
    pub fn mark_queried(&self, subnet: Subnet, query_timeout: tokio::time::Instant) {
        if !matches!(subnet.address(), IpAddr::V6(_)) {
            panic!("IP v4 addresses are not supported")
        };
        // Start a task to expire the queried state if we didn't have any results in time.
        // The task drives the state machine Queried -> NoRoute -> removed; each transition is
        // applied through the oplog so it only takes effect if the entry is still in the
        // expected state by then (see the Absorb impl).
        {
            // We only need the write handle in the task
            let writer = self.writer.clone();
            let cancel_token = self.shared.cancel_token.clone();
            tokio::task::spawn(async move {
                select! {
                    _ = cancel_token.cancelled() => {
                        // Future got cancelled, nothing to do
                        return
                    }
                    _ = tokio::time::sleep_until(query_timeout) => {
                        // Timeout fired, mark as no route
                    }
                }
                let expiry = tokio::time::Instant::now() + NO_ROUTE_EXPIRATION;
                // Scope this so the lock for the write_handle goes out of scope when we are done
                // here, as we don't want to hold the write_handle lock while sleeping for the
                // second timeout.
                {
                    let mut write_handle = writer.lock().expect("Can lock writer");
                    write_handle.append(RoutingTableOplogEntry::QueryExpired(
                        subnet,
                        Arc::new(SubnetEntry::NoRoute { expiry }),
                    ));
                    write_handle.flush();
                }
                // TODO: Check if we are indeed marked as NoRoute here, if we aren't this can be
                // cancelled now
                select! {
                    _ = cancel_token.cancelled() => {
                        // Future got cancelled, nothing to do
                        return
                    }
                    _ = tokio::time::sleep_until(expiry) => {
                        // Timeout fired, remove no route entry
                    }
                }
                let mut write_handle = writer.lock().expect("Can lock writer");
                write_handle.append(RoutingTableOplogEntry::NoRouteExpired(subnet));
                write_handle.flush();
            });
        }
        let mut write_handle = self.writer.lock().expect("Can lock writer");
        write_handle.append(RoutingTableOplogEntry::Queried(
            subnet,
            Arc::new(SubnetEntry::Queried { query_timeout }),
        ));
        write_handle.flush();
    }
}
/// Read access to a [`RouteList`], backed by an arc-swap snapshot. The snapshot is taken when
/// the guard is created; later writes to the list are not visible through it.
pub struct RouteListReadGuard {
    inner: arc_swap::Guard<Arc<RouteList>>,
}

impl Deref for RouteListReadGuard {
    type Target = RouteList;

    fn deref(&self) -> &Self::Target {
        self.inner.deref()
    }
}
/// A write guard over the [`RoutingTable`]. While this guard is held, updates won't be able to
/// complete.
pub struct RoutingTableWriteGuard<'a> {
    write_guard: MutexGuard<'a, left_right::WriteHandle<RoutingTableInner, RoutingTableOplogEntry>>,
    read_guard: left_right::ReadGuard<'a, RoutingTableInner>,
    expired_route_entry_sink: mpsc::Sender<RouteKey>,
    cancel_token: CancellationToken,
}

impl<'a, 'b> RoutingTableWriteGuard<'a> {
    /// Iterate mutably over all subnets with existing route lists. Deletions queued by the
    /// iterator are published when this guard is dropped.
    pub fn iter_mut(&'b mut self) -> RoutingTableIterMut<'a, 'b> {
        RoutingTableIterMut::new(
            &mut self.write_guard,
            self.read_guard.table.iter(),
            self.expired_route_entry_sink.clone(),
            self.cancel_token.clone(),
        )
    }
}

impl Drop for RoutingTableWriteGuard<'_> {
    fn drop(&mut self) {
        // Make all queued oplog entries visible to readers in one batch.
        self.write_guard.publish();
    }
}
/// A read guard over the [`RoutingTable`]. While this guard is held, updates won't be able to
/// complete.
pub struct RoutingTableReadGuard<'a> {
    guard: left_right::ReadGuard<'a, RoutingTableInner>,
}

impl RoutingTableReadGuard<'_> {
    /// Create an iterator over all subnets with existing route lists.
    pub fn iter(&self) -> RoutingTableIter {
        RoutingTableIter::new(self.guard.table.iter())
    }

    /// Create an iterator for all queried subnets in the routing table
    pub fn iter_queries(&self) -> RoutingTableQueryIter {
        RoutingTableQueryIter::new(self.guard.table.iter())
    }

    /// Create an iterator for all subnets which are currently marked as `NoRoute` in the routing
    /// table.
    pub fn iter_no_route(&self) -> RoutingTableNoRouteIter {
        RoutingTableNoRouteIter::new(self.guard.table.iter())
    }
}
impl WriteGuard<'_> {
    /// Loads the current [`RouteList`].
    #[inline]
    pub fn routes(&self) -> RouteListReadGuard {
        let SubnetEntry::Exists { list } = &*self.value else {
            panic!("Write guard for non-route SubnetEntry")
        };
        RouteListReadGuard { inner: list.load() }
    }

    /// Get mutable access to the [`RouteList`]. This will update the [`RouteList`] in place
    /// without locking the [`RoutingTable`].
    ///
    /// Returns the value produced by `op`, or `false` if this guard does not wrap an existing
    /// route list. NOTE(review): `rcu` may invoke `op` more than once on contention, so `op`
    /// should be side-effect free apart from mutating the list — confirm with callers.
    // TODO: Proper abstractions
    pub fn update_routes<
        F: FnMut(&mut RouteList, &mpsc::Sender<RouteKey>, &CancellationToken) -> bool,
    >(
        &mut self,
        mut op: F,
    ) -> bool {
        let mut res = false;
        let mut delete = false;
        if let SubnetEntry::Exists { list } = &*self.value {
            list.rcu(|rl| {
                // Copy-on-write: clone the list, mutate the copy, swap it in.
                let mut new_val = rl.clone();
                let v = Arc::make_mut(&mut new_val);
                res = op(v, &self.expired_route_entry_sink, &self.cancellation_token);
                delete = v.is_empty();
                new_val
            });

            // If the update emptied an existing list, queue removal of the subnet entry.
            if delete && self.exists {
                trace!(subnet = %self.subnet, "Deleting subnet which became empty after updating");
                let mut writer = self.routing_table.writer.lock().unwrap();
                writer.append(RoutingTableOplogEntry::Delete(self.subnet));
                writer.publish();
            }

            res
        } else {
            false
        }
    }

    /// Set the [`RouteEntry`] with the given [`neighbour`](Peer) as the selected route.
    pub fn set_selected(&mut self, neighbour: &Peer) {
        if let SubnetEntry::Exists { list } = &*self.value {
            list.rcu(|routes| {
                let mut new_routes = routes.clone();
                let routes = Arc::make_mut(&mut new_routes);
                let Some(pos) = routes.iter().position(|re| re.neighbour() == neighbour) else {
                    error!(
                        neighbour = neighbour.connection_identifier(),
                        "Failed to select route entry with given route key, no such entry"
                    );
                    return new_routes;
                };

                // We don't need a check for an empty list here, since we found a selected route there
                // _MUST_ be at least 1 entry.
                // Set the first element to unselected, then select the proper element so this also works
                // in case the existing route is "reselected".
                routes[0].set_selected(false);
                routes[pos].set_selected(true);
                // Maintain the invariant that the selected route sits at index 0.
                routes.swap(0, pos);

                new_routes
            });
        }
    }

    /// Unconditionally unselects the selected route, if one is present.
    ///
    /// In case no route is selected, this is a no-op.
    pub fn unselect(&mut self) {
        if let SubnetEntry::Exists { list } = &*self.value {
            list.rcu(|v| {
                let mut new_val = v.clone();
                let new_ref = Arc::make_mut(&mut new_val);
                // The selected route, if any, is always at index 0.
                if let Some(e) = new_ref.get_mut(0) {
                    e.set_selected(false);
                }
                new_val
            });
        }
    }
}
impl Drop for WriteGuard<'_> {
    /// Reconciles the routing table with what happened while the guard was held: insert the
    /// entry if it is new and non-empty, delete it if it existed and is now empty.
    fn drop(&mut self) {
        // FIXME: try to get rid of clones on the Arc here
        if let SubnetEntry::Exists { list } = &*self.value {
            let value = list.load();
            match self.exists {
                // The route list did not exist, and now it is not empty, so an entry was added. We
                // need to add the route list to the routing table.
                false if !value.is_empty() => {
                    trace!(subnet = %self.subnet, "Inserting new route list for subnet");
                    let mut writer = self.routing_table.writer.lock().unwrap();
                    writer.append(RoutingTableOplogEntry::Upsert(
                        self.subnet,
                        Arc::clone(&self.value),
                    ));
                    writer.publish();
                }
                // There was an existing route list which is now empty, so the entry for this subnet
                // needs to be deleted in the routing table.
                true if value.is_empty() => {
                    trace!(subnet = %self.subnet, "Removing route list for subnet");
                    let mut writer = self.routing_table.writer.lock().unwrap();
                    writer.append(RoutingTableOplogEntry::Delete(self.subnet));
                    writer.publish();
                }
                // Nothing to do in these cases. Either no value was inserted in a non existing
                // routelist, or an existing one was updated in place.
                _ => {}
            }
        }
    }
}
/// Operations allowed on the left_right for the routing table.
///
/// Each variant is applied twice (once per copy of the table) by the [`left_right::Absorb`]
/// implementation below.
enum RoutingTableOplogEntry {
    /// Insert or Update the value for the given subnet.
    Upsert(Subnet, Arc<SubnetEntry>),
    /// Mark a subnet as queried.
    Queried(Subnet, Arc<SubnetEntry>),
    /// Delete the entry for the given subnet.
    Delete(Subnet),
    /// The route request for a subnet expired, if it is still in query state mark it as not
    /// existing
    QueryExpired(Subnet, Arc<SubnetEntry>),
    /// The marker for explicitly not having a route to a subnet has expired
    NoRouteExpired(Subnet),
}
/// Convert an [`IpAddr`] into an [`Ipv6Addr`].
///
/// # Panics
///
/// Panics if the contained address is not an IPv6 address.
fn expect_ipv6(ip: IpAddr) -> Ipv6Addr {
    match ip {
        IpAddr::V6(v6) => v6,
        IpAddr::V4(_) => panic!("Expected ipv6 address"),
    }
}
impl left_right::Absorb<RoutingTableOplogEntry> for RoutingTableInner {
    /// Applies an operation to the first copy of the table.
    ///
    /// NOTE: this must stay semantically identical to `absorb_second`, which applies the same
    /// operation to the other copy; any divergence desynchronizes the two table halves.
    fn absorb_first(&mut self, operation: &mut RoutingTableOplogEntry, _other: &Self) {
        match operation {
            RoutingTableOplogEntry::Upsert(subnet, list) => {
                self.table.insert(
                    expect_ipv6(subnet.address()),
                    subnet.prefix_len().into(),
                    Arc::clone(list),
                );
            }
            RoutingTableOplogEntry::Queried(subnet, se) => {
                // Mark a query only if we don't have a valid entry
                let entry = self
                    .table
                    .exact_match(expect_ipv6(subnet.address()), subnet.prefix_len().into())
                    .map(Arc::deref);
                // If we have no route, transition to query, if we have a route or existing query,
                // do nothing
                if matches!(entry, None | Some(SubnetEntry::NoRoute { .. })) {
                    self.table.insert(
                        expect_ipv6(subnet.address()),
                        subnet.prefix_len().into(),
                        Arc::clone(se),
                    );
                }
            }
            RoutingTableOplogEntry::Delete(subnet) => {
                self.table
                    .remove(expect_ipv6(subnet.address()), subnet.prefix_len().into());
            }
            RoutingTableOplogEntry::QueryExpired(subnet, nre) => {
                // Only transition Queried -> NoRoute; if routes arrived in the meantime the
                // entry is no longer Queried and must be left alone.
                if let Some(entry) = self
                    .table
                    .exact_match(expect_ipv6(subnet.address()), subnet.prefix_len().into())
                {
                    if let SubnetEntry::Queried { .. } = &**entry {
                        self.table.insert(
                            expect_ipv6(subnet.address()),
                            subnet.prefix_len().into(),
                            Arc::clone(nre),
                        );
                    }
                }
            }
            RoutingTableOplogEntry::NoRouteExpired(subnet) => {
                // Only remove the entry if it is still in NoRoute state.
                if let Some(entry) = self
                    .table
                    .exact_match(expect_ipv6(subnet.address()), subnet.prefix_len().into())
                {
                    if let SubnetEntry::NoRoute { .. } = &**entry {
                        self.table
                            .remove(expect_ipv6(subnet.address()), subnet.prefix_len().into());
                    }
                }
            }
        }
    }

    /// Copies every entry from the other table half; cloning the `Arc`s is cheap and both
    /// halves end up sharing the same `SubnetEntry` allocations.
    fn sync_with(&mut self, first: &Self) {
        for (k, ss, v) in first.table.iter() {
            self.table.insert(k, ss, v.clone());
        }
    }

    /// Applies the same operation to the second copy of the table; mirror of `absorb_first`,
    /// but consumes the operation by value.
    fn absorb_second(&mut self, operation: RoutingTableOplogEntry, _: &Self) {
        match operation {
            RoutingTableOplogEntry::Upsert(subnet, list) => {
                self.table.insert(
                    expect_ipv6(subnet.address()),
                    subnet.prefix_len().into(),
                    list,
                );
            }
            RoutingTableOplogEntry::Queried(subnet, se) => {
                // Mark a query only if we don't have a valid entry
                let entry = self
                    .table
                    .exact_match(expect_ipv6(subnet.address()), subnet.prefix_len().into())
                    .map(Arc::deref);
                // If we have no route, transition to query, if we have a route or existing query,
                // do nothing
                if matches!(entry, None | Some(SubnetEntry::NoRoute { .. })) {
                    self.table.insert(
                        expect_ipv6(subnet.address()),
                        subnet.prefix_len().into(),
                        se,
                    );
                }
            }
            RoutingTableOplogEntry::Delete(subnet) => {
                self.table
                    .remove(expect_ipv6(subnet.address()), subnet.prefix_len().into());
            }
            RoutingTableOplogEntry::QueryExpired(subnet, nre) => {
                if let Some(entry) = self
                    .table
                    .exact_match(expect_ipv6(subnet.address()), subnet.prefix_len().into())
                {
                    if let SubnetEntry::Queried { .. } = &**entry {
                        self.table.insert(
                            expect_ipv6(subnet.address()),
                            subnet.prefix_len().into(),
                            nre,
                        );
                    }
                }
            }
            RoutingTableOplogEntry::NoRouteExpired(subnet) => {
                if let Some(entry) = self
                    .table
                    .exact_match(expect_ipv6(subnet.address()), subnet.prefix_len().into())
                {
                    if let SubnetEntry::NoRoute { .. } = &**entry {
                        self.table
                            .remove(expect_ipv6(subnet.address()), subnet.prefix_len().into());
                    }
                }
            }
        }
    }
}
impl Drop for RoutingTableShared {
    /// Runs when the last clone of the [`RoutingTable`] is dropped: cancel all background
    /// expiration tasks spawned by `mark_queried` so they don't outlive the table.
    fn drop(&mut self) {
        self.cancel_token.cancel();
    }
}

View File

@@ -0,0 +1,100 @@
use std::{net::Ipv6Addr, sync::Arc};
use crate::subnet::Subnet;
use super::{subnet_entry::SubnetEntry, NoRouteSubnet, QueriedSubnet, RouteListReadGuard};
/// An iterator over a [`routing table`](super::RoutingTable) giving read only access to
/// [`RouteList`]'s.
pub struct RoutingTableIter<'a>(
    ip_network_table_deps_treebitmap::Iter<'a, Ipv6Addr, Arc<SubnetEntry>>,
);

impl<'a> RoutingTableIter<'a> {
    /// Create a new `RoutingTableIter` which will iterate over all entries in a [`RoutingTable`].
    pub(super) fn new(
        inner: ip_network_table_deps_treebitmap::Iter<'a, Ipv6Addr, Arc<SubnetEntry>>,
    ) -> Self {
        Self(inner)
    }
}
impl Iterator for RoutingTableIter<'_> {
    type Item = (Subnet, RouteListReadGuard);

    /// Yields the next subnet with an actual route list, skipping entries in the
    /// `Queried` or `NoRoute` state.
    fn next(&mut self) -> Option<Self::Item> {
        self.0.by_ref().find_map(|(ip, prefix_size, rl)| match &**rl {
            SubnetEntry::Exists { list } => Some((
                Subnet::new(ip.into(), prefix_size as u8)
                    .expect("Routing table contains valid subnets"),
                RouteListReadGuard { inner: list.load() },
            )),
            _ => None,
        })
    }
}
/// Iterator over queried routes in the routing table.
pub struct RoutingTableQueryIter<'a>(
    ip_network_table_deps_treebitmap::Iter<'a, Ipv6Addr, Arc<SubnetEntry>>,
);

impl<'a> RoutingTableQueryIter<'a> {
    /// Create a new `RoutingTableQueryIter` which will iterate over all queried entries in a [`RoutingTable`].
    pub(super) fn new(
        inner: ip_network_table_deps_treebitmap::Iter<'a, Ipv6Addr, Arc<SubnetEntry>>,
    ) -> Self {
        Self(inner)
    }
}
impl Iterator for RoutingTableQueryIter<'_> {
    type Item = QueriedSubnet;

    /// Yields the next subnet in the `Queried` state, skipping all other entry kinds.
    fn next(&mut self) -> Option<Self::Item> {
        self.0.by_ref().find_map(|(ip, prefix_size, rl)| match &**rl {
            SubnetEntry::Queried { query_timeout } => Some(QueriedSubnet::new(
                Subnet::new(ip.into(), prefix_size as u8)
                    .expect("Routing table contains valid subnets"),
                *query_timeout,
            )),
            _ => None,
        })
    }
}
/// Iterator for entries which are explicitly marked as "no route" in the routing table.
pub struct RoutingTableNoRouteIter<'a>(
    ip_network_table_deps_treebitmap::Iter<'a, Ipv6Addr, Arc<SubnetEntry>>,
);

impl<'a> RoutingTableNoRouteIter<'a> {
    /// Create a new `RoutingTableNoRouteIter` which will iterate over all entries in a [`RoutingTable`]
    /// which are explicitly marked as `NoRoute`
    pub(super) fn new(
        inner: ip_network_table_deps_treebitmap::Iter<'a, Ipv6Addr, Arc<SubnetEntry>>,
    ) -> Self {
        Self(inner)
    }
}
impl Iterator for RoutingTableNoRouteIter<'_> {
    type Item = NoRouteSubnet;

    /// Yields the next subnet in the `NoRoute` state, skipping all other entry kinds.
    fn next(&mut self) -> Option<Self::Item> {
        self.0.by_ref().find_map(|(ip, prefix_size, rl)| match &**rl {
            SubnetEntry::NoRoute { expiry } => Some(NoRouteSubnet::new(
                Subnet::new(ip.into(), prefix_size as u8)
                    .expect("Routing table contains valid subnets"),
                *expiry,
            )),
            _ => None,
        })
    }
}

View File

@@ -0,0 +1,107 @@
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use tracing::trace;
use crate::subnet::Subnet;
use super::{
subnet_entry::SubnetEntry, RouteKey, RouteList, RoutingTableInner, RoutingTableOplogEntry,
};
use std::{
net::Ipv6Addr,
sync::{Arc, MutexGuard},
};
/// An iterator over a [`routing table`](super::RoutingTable), yielding mutable access to the
/// entries in the table.
pub struct RoutingTableIterMut<'a, 'b> {
    /// Write handle used by yielded entries to queue deletions of emptied subnets.
    write_guard:
        &'b mut MutexGuard<'a, left_right::WriteHandle<RoutingTableInner, RoutingTableOplogEntry>>,
    /// Underlying iterator over the read copy of the table.
    iter: ip_network_table_deps_treebitmap::Iter<'b, Ipv6Addr, Arc<SubnetEntry>>,
    /// Cloned into every yielded entry so update callbacks can report expirations.
    expired_route_entry_sink: mpsc::Sender<RouteKey>,
    /// Cloned into every yielded entry so spawned work can be cancelled with the table.
    cancel_token: CancellationToken,
}

impl<'a, 'b> RoutingTableIterMut<'a, 'b> {
    pub(super) fn new(
        write_guard: &'b mut MutexGuard<
            'a,
            left_right::WriteHandle<RoutingTableInner, RoutingTableOplogEntry>,
        >,
        iter: ip_network_table_deps_treebitmap::Iter<'b, Ipv6Addr, Arc<SubnetEntry>>,
        expired_route_entry_sink: mpsc::Sender<RouteKey>,
        cancel_token: CancellationToken,
    ) -> Self {
        Self {
            write_guard,
            iter,
            expired_route_entry_sink,
            cancel_token,
        }
    }

    /// Get the next item in this iterator. This is not implemented as the [`Iterator`] trait,
    /// since we hand out items which are lifetime bound to this struct.
    ///
    /// Only subnets with an existing route list are yielded; `Queried` and `NoRoute` entries
    /// are skipped.
    pub fn next<'c>(&'c mut self) -> Option<(Subnet, RoutingTableIterMutEntry<'a, 'c>)> {
        for (ip, prefix_size, rl) in self.iter.by_ref() {
            if matches!(&**rl, SubnetEntry::Exists { .. }) {
                let subnet = Subnet::new(ip.into(), prefix_size as u8)
                    .expect("Routing table contains valid subnets");
                return Some((
                    subnet,
                    RoutingTableIterMutEntry {
                        writer: self.write_guard,
                        store: Arc::clone(rl),
                        subnet,
                        expired_route_entry_sink: self.expired_route_entry_sink.clone(),
                        cancellation_token: self.cancel_token.clone(),
                    },
                ));
            };
        }
        None
    }
}
/// A smart pointer giving mutable access to a [`RouteList`].
pub struct RoutingTableIterMutEntry<'a, 'b> {
    /// Write handle used to queue a delete when the route list becomes empty.
    writer:
        &'b mut MutexGuard<'a, left_right::WriteHandle<RoutingTableInner, RoutingTableOplogEntry>>,
    /// Shared handle to the subnet entry; mutations go through the arc-swap inside it.
    store: Arc<SubnetEntry>,
    /// The subnet we are writing to.
    subnet: Subnet,
    expired_route_entry_sink: mpsc::Sender<RouteKey>,
    cancellation_token: CancellationToken,
}

impl RoutingTableIterMutEntry<'_, '_> {
    /// Updates the routes for this entry
    ///
    /// NOTE(review): `rcu` may invoke `op` more than once under contention, so `op` should be
    /// side-effect free apart from mutating the list — confirm with callers. If the list ends
    /// up empty, deletion of the subnet is queued (published when the surrounding write guard
    /// drops).
    pub fn update_routes<F: FnMut(&mut RouteList, &mpsc::Sender<RouteKey>, &CancellationToken)>(
        &mut self,
        mut op: F,
    ) {
        let mut delete = false;
        if let SubnetEntry::Exists { list } = &*self.store {
            list.rcu(|rl| {
                // Copy-on-write: clone the list, mutate the copy, swap it in.
                let mut new_val = rl.clone();
                let v = Arc::make_mut(&mut new_val);
                op(v, &self.expired_route_entry_sink, &self.cancellation_token);
                delete = v.is_empty();
                new_val
            });

            if delete {
                trace!(subnet = %self.subnet, "Queue subnet for deletion since route list is now empty");
                self.writer
                    .append(RoutingTableOplogEntry::Delete(self.subnet));
            }
        }
    }
}

View File

@@ -0,0 +1,35 @@
use tokio::time::Instant;
use crate::subnet::Subnet;
/// Information about a [`subnet`](Subnet) which is currently marked as NoRoute.
#[derive(Debug, Clone, Copy)]
pub struct NoRouteSubnet {
    /// The subnet which has no route.
    subnet: Subnet,
    /// Time at which the entry expires. After this timeout expires, the entry is removed and a new
    /// query can be performed.
    entry_expires: Instant,
}

impl NoRouteSubnet {
    /// Create a new `NoRouteSubnet` for the given [`subnet`](Subnet), expiring at the provided
    /// [`time`](Instant).
    pub fn new(subnet: Subnet, entry_expires: Instant) -> Self {
        Self {
            subnet,
            entry_expires,
        }
    }

    /// The [`subnet`](Subnet) for which there is no route.
    pub fn subnet(&self) -> Subnet {
        self.subnet
    }

    /// The moment this entry expires. Once this timeout expires, a new query can be launched for
    /// route discovery for this [`subnet`](Subnet).
    pub fn entry_expires(&self) -> Instant {
        self.entry_expires
    }
}

View File

@@ -0,0 +1,35 @@
use tokio::time::Instant;
use crate::subnet::Subnet;
/// Information about a [`subnet`](Subnet) which is currently in the queried state
#[derive(Debug, Clone, Copy)]
pub struct QueriedSubnet {
    /// The subnet which was queried.
    subnet: Subnet,
    /// Time at which the query expires. If no feasible updates come in before this, the subnet is
    /// marked as no route temporarily.
    query_expires: Instant,
}

impl QueriedSubnet {
    /// Create a new `QueriedSubnet` for the given [`subnet`](Subnet), expiring at the provided
    /// [`time`](Instant).
    pub fn new(subnet: Subnet, query_expires: Instant) -> Self {
        Self {
            subnet,
            query_expires,
        }
    }

    /// The [`subnet`](Subnet) being queried.
    pub fn subnet(&self) -> Subnet {
        self.subnet
    }

    /// The moment this query expires. If no route is discovered before this, the [`subnet`](Subnet)
    /// is marked as no route temporarily.
    pub fn query_expires(&self) -> Instant {
        self.query_expires
    }
}

View File

@@ -0,0 +1,120 @@
use tokio::time::Instant;
use crate::{
metric::Metric, peer::Peer, router_id::RouterId, sequence_number::SeqNo,
source_table::SourceKey,
};
/// RouteEntry holds all relevant information about a specific route. Since this includes the next
/// hop, a single subnet can have multiple route entries.
#[derive(Clone)]
pub struct RouteEntry {
    /// Source key (router id + subnet) this route was announced from.
    source: SourceKey,
    /// Next-hop peer for this route.
    neighbour: Peer,
    /// Advertised metric of the route through this neighbour.
    metric: Metric,
    /// Sequence number of the route announcement.
    seqno: SeqNo,
    /// Whether this entry is the currently selected route for its subnet.
    selected: bool,
    /// Moment the entry expires unless it is refreshed.
    expires: Instant,
}

impl RouteEntry {
    /// Create a new `RouteEntry` with the provided values.
    pub fn new(
        source: SourceKey,
        neighbour: Peer,
        metric: Metric,
        seqno: SeqNo,
        selected: bool,
        expires: Instant,
    ) -> Self {
        Self {
            source,
            neighbour,
            metric,
            seqno,
            selected,
            expires,
        }
    }

    /// Return the [`SourceKey`] for this `RouteEntry`.
    pub fn source(&self) -> SourceKey {
        self.source
    }

    /// Return the [`neighbour`](Peer) used as next hop for this `RouteEntry`.
    pub fn neighbour(&self) -> &Peer {
        &self.neighbour
    }

    /// Return the [`Metric`] of this `RouteEntry`.
    pub fn metric(&self) -> Metric {
        self.metric
    }

    /// Return the [`sequence number`](SeqNo) for the `RouteEntry`.
    pub fn seqno(&self) -> SeqNo {
        self.seqno
    }

    /// Return if this [`RouteEntry`] is selected.
    pub fn selected(&self) -> bool {
        self.selected
    }

    /// Return the [`Instant`] when this `RouteEntry` expires if it doesn't get updated before
    /// then.
    pub fn expires(&self) -> Instant {
        self.expires
    }

    /// Set the [`SourceKey`] for this `RouteEntry`.
    pub fn set_source(&mut self, source: SourceKey) {
        self.source = source;
    }

    /// Set the [`RouterId`] for this `RouteEntry`.
    pub fn set_router_id(&mut self, router_id: RouterId) {
        self.source.set_router_id(router_id)
    }

    /// Sets the [`neighbour`](Peer) for this `RouteEntry`.
    pub fn set_neighbour(&mut self, neighbour: Peer) {
        self.neighbour = neighbour;
    }

    /// Sets the [`Metric`] for this `RouteEntry`.
    pub fn set_metric(&mut self, metric: Metric) {
        self.metric = metric;
    }

    /// Sets the [`sequence number`](SeqNo) for this `RouteEntry`.
    pub fn set_seqno(&mut self, seqno: SeqNo) {
        self.seqno = seqno;
    }

    /// Sets if this `RouteEntry` is the selected route for the associated
    /// [`Subnet`](crate::subnet::Subnet).
    pub fn set_selected(&mut self, selected: bool) {
        self.selected = selected;
    }

    /// Sets the expiration time for this [`RouteEntry`].
    pub(super) fn set_expires(&mut self, expires: Instant) {
        self.expires = expires;
    }
}
// Manual Debug implementation listing every field explicitly.
// NOTE(review): the original comment claimed this is needed because SharedSecret is not Debug,
// but RouteEntry holds no SharedSecret — likely copied from RouteList; confirm whether this
// impl can simply be derived.
impl std::fmt::Debug for RouteEntry {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RouteEntry")
            .field("source", &self.source)
            .field("neighbour", &self.neighbour)
            .field("metric", &self.metric)
            .field("seqno", &self.seqno)
            .field("selected", &self.selected)
            .field("expires", &self.expires)
            .finish()
    }
}

View File

@@ -0,0 +1,38 @@
use crate::{peer::Peer, subnet::Subnet};
/// RouteKey uniquely defines a route via a peer.
#[derive(Debug, Clone, PartialEq)]
pub struct RouteKey {
    // Destination subnet this route covers.
    subnet: Subnet,
    // Next-hop peer used to reach the subnet.
    neighbour: Peer,
}
impl RouteKey {
/// Creates a new `RouteKey` for the given [`Subnet`] and [`neighbour`](Peer).
#[inline]
pub fn new(subnet: Subnet, neighbour: Peer) -> Self {
Self { subnet, neighbour }
}
/// Get's the [`Subnet`] identified by this `RouteKey`.
#[inline]
pub fn subnet(&self) -> Subnet {
self.subnet
}
/// Gets the [`neighbour`](Peer) identified by this `RouteKey`.
#[inline]
pub fn neighbour(&self) -> &Peer {
&self.neighbour
}
}
impl std::fmt::Display for RouteKey {
    /// Renders as "<subnet> via <connection identifier>".
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "{} via {}",
            self.subnet,
            self.neighbour.connection_identifier()
        )
    }
}

View File

@@ -0,0 +1,201 @@
use std::{
ops::{Deref, DerefMut, Index, IndexMut},
sync::Arc,
};
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use tracing::{debug, error};
use crate::{crypto::SharedSecret, peer::Peer, task::AbortHandle};
use super::{RouteEntry, RouteKey};
/// The RouteList holds all routes for a specific subnet.
// By convention, if a route is selected, it will always be at index 0 in the list.
#[derive(Clone)]
pub struct RouteList {
    // Every entry is paired with the abort handle of its expiration watchdog task, so removing
    // or replacing the entry can cancel the timer.
    list: Vec<(Arc<AbortHandle>, RouteEntry)>,
    // Secret used for encrypting traffic to/from the subnet; cloned together with the list.
    shared_secret: SharedSecret,
}
impl RouteList {
    /// Create a new empty RouteList
    pub(crate) fn new(shared_secret: SharedSecret) -> Self {
        Self {
            list: Vec::new(),
            shared_secret,
        }
    }

    /// Returns the [`SharedSecret`] used for encryption of packets to and from the associated
    /// [`Subnet`].
    #[inline]
    pub fn shared_secret(&self) -> &SharedSecret {
        &self.shared_secret
    }

    /// Checks if there are any actual routes in the list.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.list.is_empty()
    }

    /// Returns the selected route for the [`Subnet`] this is the `RouteList` for, if one exists.
    // Only index 0 is inspected: a selected route is at index 0 by convention (see type docs).
    pub fn selected(&self) -> Option<&RouteEntry> {
        self.list
            .first()
            .map(|(_, re)| re)
            .and_then(|re| if re.selected() { Some(re) } else { None })
    }

    /// Returns an iterator over the `RouteList`.
    ///
    /// The iterator yields all [`route entries`](RouteEntry) in the list.
    pub fn iter(&self) -> RouteListIter {
        RouteListIter::new(self)
    }

    /// Returns an iterator over the `RouteList` yielding mutable access to the elements.
    ///
    /// The iterator yields all [`route entries`](RouteEntry) in the list.
    pub fn iter_mut(&mut self) -> impl Iterator<Item = RouteGuard> {
        self.list.iter_mut().map(|item| RouteGuard { item })
    }

    /// Removes a [`RouteEntry`] from the `RouteList`.
    ///
    /// This does nothing if the neighbour does not exist.
    pub fn remove(&mut self, neighbour: &Peer) {
        let Some(pos) = self
            .list
            .iter()
            .position(|re| re.1.neighbour() == neighbour)
        else {
            return;
        };

        // swap_remove is O(1) but does not preserve order of later elements.
        let old = self.list.swap_remove(pos);
        // Cancel the expiration watchdog of the removed entry.
        old.0.abort();
    }

    /// Swaps the position of 2 `RouteEntry`s in the route list.
    pub fn swap(&mut self, first: usize, second: usize) {
        self.list.swap(first, second)
    }

    /// Returns mutable access to the [`RouteEntry`] at `index`, if there is one.
    pub fn get_mut(&mut self, index: usize) -> Option<&mut RouteEntry> {
        self.list.get_mut(index).map(|(_, re)| re)
    }

    /// Insert a new [`RouteEntry`] in the `RouteList`.
    ///
    /// A watchdog task is spawned which, once the entry expires, pushes the matching
    /// [`RouteKey`] on `expired_route_entry_sink` — unless `cancellation_token` fires first.
    /// The task's abort handle is stored next to the entry so removal cancels the timer.
    pub fn insert(
        &mut self,
        re: RouteEntry,
        expired_route_entry_sink: mpsc::Sender<RouteKey>,
        cancellation_token: CancellationToken,
    ) {
        let expiration = re.expires();
        let rk = RouteKey::new(re.source().subnet(), re.neighbour().clone());
        let abort_handle = Arc::new(
            tokio::spawn(async move {
                tokio::select! {
                    _ = cancellation_token.cancelled() => {}
                    _ = tokio::time::sleep_until(expiration) => {
                        debug!(route_key = %rk, "Expired route entry for route key");
                        if let Err(e) = expired_route_entry_sink.send(rk).await {
                            error!(route_key = %e.0, "Failed to send expired route key on cleanup channel");
                        }
                    }
                }
            })
            .abort_handle()
            .into(),
        );

        self.list.push((abort_handle, re));
    }
}
/// Mutable access wrapper around a single route list entry, handed out by
/// [`RouteList::iter_mut`]. Derefs to [`RouteEntry`].
pub struct RouteGuard<'a> {
    // The guarded (expiration timer handle, entry) pair.
    item: &'a mut (Arc<AbortHandle>, RouteEntry),
}
impl Deref for RouteGuard<'_> {
    type Target = RouteEntry;

    /// Borrow the wrapped [`RouteEntry`].
    fn deref(&self) -> &Self::Target {
        let (_, entry) = &*self.item;
        entry
    }
}
impl DerefMut for RouteGuard<'_> {
    /// Mutably borrow the wrapped [`RouteEntry`].
    fn deref_mut(&mut self) -> &mut Self::Target {
        let (_, entry) = &mut *self.item;
        entry
    }
}
impl RouteGuard<'_> {
    /// Update the expiration time of the guarded [`RouteEntry`] and rearm its watchdog.
    ///
    /// A new expiration task is spawned for the updated deadline, after which the previous
    /// task is aborted and its handle replaced — so at most one live timer exists per entry.
    pub fn set_expires(
        &mut self,
        expires: tokio::time::Instant,
        expired_route_entry_sink: mpsc::Sender<RouteKey>,
        cancellation_token: CancellationToken,
    ) {
        let re = &mut self.item.1;
        re.set_expires(expires);
        let expiration = re.expires();
        let rk = RouteKey::new(re.source().subnet(), re.neighbour().clone());
        let abort_handle = Arc::new(
            tokio::spawn(async move {
                tokio::select! {
                    _ = cancellation_token.cancelled() => {}
                    _ = tokio::time::sleep_until(expiration) => {
                        debug!(route_key = %rk, "Expired route entry for route key");
                        if let Err(e) = expired_route_entry_sink.send(rk).await {
                            error!(route_key = %e.0, "Failed to send expired route key on cleanup channel");
                        }
                    }
                }
            })
            .abort_handle()
            .into(),
        );

        // Cancel the old timer before storing the new handle.
        self.item.0.abort();
        self.item.0 = abort_handle;
    }
}
impl Index<usize> for RouteList {
    type Output = RouteEntry;

    /// Index into the entries, ignoring the paired timer handles. Panics on out of bounds.
    fn index(&self, index: usize) -> &Self::Output {
        let (_, entry) = &self.list[index];
        entry
    }
}
impl IndexMut<usize> for RouteList {
    /// Mutable indexing counterpart of [`Index`]. Panics on out of bounds.
    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
        let (_, entry) = &mut self.list[index];
        entry
    }
}
/// Borrowing iterator over the [`RouteEntry`]s in a [`RouteList`].
pub struct RouteListIter<'a> {
    // The list being iterated.
    route_list: &'a RouteList,
    // Index of the next entry to yield.
    idx: usize,
}
impl<'a> RouteListIter<'a> {
/// Create a new `RouteListIter` which will iterate over the given [`RouteList`].
fn new(route_list: &'a RouteList) -> Self {
Self { route_list, idx: 0 }
}
}
impl<'a> Iterator for RouteListIter<'a> {
    type Item = &'a RouteEntry;

    /// Yield the next [`RouteEntry`], skipping over the paired timer handles.
    fn next(&mut self) -> Option<Self::Item> {
        let item = self.route_list.list.get(self.idx).map(|(_, re)| re);
        self.idx += 1;
        item
    }

    /// The exact number of remaining entries is known, so report it; this lets `collect`
    /// and friends preallocate instead of falling back to the default `(0, None)`.
    fn size_hint(&self) -> (usize, Option<usize>) {
        // saturating_sub: `idx` keeps incrementing on repeated `next` calls past the end.
        let remaining = self.route_list.list.len().saturating_sub(self.idx);
        (remaining, Some(remaining))
    }
}

View File

@@ -0,0 +1,16 @@
use arc_swap::ArcSwap;
use super::RouteList;
/// An entry for a [Subnet](crate::subnet::Subnet) in the routing table.
#[allow(dead_code)]
pub enum SubnetEntry {
    /// Routes for the given subnet exist
    Exists {
        /// The routes; wrapped in an [`ArcSwap`] so readers can load without locking.
        list: ArcSwap<RouteList>,
    },
    /// Routes are being queried from peers for the given subnet, but we haven't gotten a response
    /// yet
    Queried {
        /// Deadline after which the query is considered unanswered.
        query_timeout: tokio::time::Instant,
    },
    /// We queried our peers for the subnet, but we didn't get a valid response in time, so there
    /// is for sure no route to the subnet.
    NoRoute {
        /// Moment this negative result stops being valid.
        expiry: tokio::time::Instant,
    },
}

103
mycelium/src/rr_cache.rs Normal file
View File

@@ -0,0 +1,103 @@
//! This module contains a cache implementation for route requests
use std::{
net::{IpAddr, Ipv6Addr},
sync::Arc,
};
use dashmap::DashMap;
use tokio::time::{Duration, Instant};
use tracing::trace;
use crate::{babel::RouteRequest, peer::Peer, subnet::Subnet, task::AbortHandle};
/// Clean the route request cache every 5 seconds
const CACHE_CLEANING_INTERVAL: Duration = Duration::from_secs(5);
/// IP used for the [`Subnet`] in the cache in case there is no prefix specified.
///
/// This is the well-known all-zero (unspecified) IPv6 address.
const GLOBAL_SUBNET_IP: IpAddr = IpAddr::V6(Ipv6Addr::UNSPECIFIED);
/// Prefix size to use for the [`Subnet`] in case there is no prefix specified.
const GLOBAL_SUBNET_PREFIX_SIZE: u8 = 0;
/// A self cleaning cache for route requests.
#[derive(Clone)]
pub struct RouteRequestCache {
    /// The actual cache, mapping an instance of a route request to the peers which we've sent this
    /// to.
    cache: Arc<DashMap<Subnet, RouteRequestInfo, ahash::RandomState>>,
    // Handle for the periodic cleanup task, kept so the task's lifetime is tied to the cache.
    // NOTE(review): presumably the project's AbortHandle wrapper aborts on drop — confirm.
    _cleanup_task: Arc<AbortHandle>,
}
/// Bookkeeping for a single route request sent out by this node.
struct RouteRequestInfo {
    /// The lowest generation we've forwarded.
    generation: u8,
    /// Peers which we've sent this route request to already.
    receivers: Vec<Peer>,
    /// The moment we've sent this route request.
    sent: Instant,
}
impl RouteRequestCache {
    /// Create a new cache which cleans entries which are older than the given expiration.
    ///
    /// The cache cleaning is done periodically, so entries might live slightly longer than the
    /// allowed expiration.
    pub fn new(expiration: Duration) -> Self {
        let cache = Arc::new(DashMap::with_hasher(ahash::RandomState::new()));
        let _cleanup_task = Arc::new(
            tokio::spawn({
                let cache = cache.clone();
                async move {
                    loop {
                        tokio::time::sleep(CACHE_CLEANING_INTERVAL).await;
                        trace!("Cleaning route request cache");
                        // `retain` KEEPS entries for which the closure returns true. The original
                        // predicate was inverted, evicting fresh entries and keeping expired
                        // ones: fresh entries must map to true, expired ones to false.
                        cache.retain(|subnet, info: &mut RouteRequestInfo| {
                            if info.sent.elapsed() < expiration {
                                true
                            } else {
                                trace!(%subnet, "Removing expired route request from cache");
                                false
                            }
                        });
                    }
                }
            })
            .abort_handle()
            .into(),
        );

        Self {
            cache,
            _cleanup_task,
        }
    }

    /// Record a route request which has been sent to peers.
    ///
    /// A request without a prefix is cached under the global (::/0) subnet.
    pub fn sent_route_request(&self, rr: RouteRequest, receivers: Vec<Peer>) {
        let subnet = rr.prefix().unwrap_or(
            Subnet::new(GLOBAL_SUBNET_IP, GLOBAL_SUBNET_PREFIX_SIZE)
                .expect("Static global IPv6 subnet is valid; qed"),
        );
        let generation = rr.generation();

        let rri = RouteRequestInfo {
            generation,
            receivers,
            sent: Instant::now(),
        };
        self.cache.insert(subnet, rri);
    }

    /// Get cached info about a route request for a subnet, if it exists.
    ///
    /// Returns the recorded generation and the peers the request was sent to.
    pub fn info(&self, subnet: Subnet) -> Option<(u8, Vec<Peer>)> {
        self.cache
            .get(&subnet)
            .map(|rri| (rri.generation, rri.receivers.clone()))
    }
}

176
mycelium/src/seqno_cache.rs Normal file
View File

@@ -0,0 +1,176 @@
//! The seqno request cache keeps track of seqno requests sent by the node. This allows us to drop
//! duplicate requests, and to notify the source of requests (if it wasn't the local node) about
//! relevant updates.
use std::{
sync::Arc,
time::{Duration, Instant},
};
use dashmap::DashMap;
use tokio::time::MissedTickBehavior;
use tracing::{debug, trace};
use crate::{peer::Peer, router_id::RouterId, sequence_number::SeqNo, subnet::Subnet};
/// The amount of time to remember a seqno request (since it was first seen), before we remove it
/// (assuming it was not removed manually before that).
// Also used as the tick interval of the background sweep task.
const SEQNO_DEDUP_TTL: Duration = Duration::from_secs(60);
/// A sequence number request, either forwarded or originated by the local node.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub struct SeqnoRequestCacheKey {
    // Router id the request refers to.
    pub router_id: RouterId,
    // Subnet the request is about.
    pub subnet: Subnet,
    // The requested sequence number.
    pub seqno: SeqNo,
}
impl std::fmt::Display for SeqnoRequestCacheKey {
    /// Renders as "seqno <n> for <subnet> from <router id>".
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!(
            "seqno {} for {} from {}",
            self.seqno, self.subnet, self.router_id
        ))
    }
}
/// Information retained for sequence number requests we've sent.
struct SeqnoForwardInfo {
    /// Which peers have asked us to forward this seqno request.
    sources: Vec<Peer>,
    /// Which peers we have sent this request to.
    targets: Vec<Peer>,
    /// Time at which we first forwarded the request.
    first_sent: Instant,
    /// When we last sent a seqno request.
    last_sent: Instant,
}
/// A cache for outbound seqno requests. Entries in the cache are automatically removed after a
/// certain amount of time. The cache does not account for the source table. That is, if the
/// requested seqno is smaller, it might pass the cache, but should have been blocked earlier by
/// the source table check. As such, this cache should be the last step in deciding if a seqno
/// request is forwarded.
#[derive(Clone)]
pub struct SeqnoCache {
    /// Actual cache wrapped in an Arc to make it shareable.
    cache: Arc<DashMap<SeqnoRequestCacheKey, SeqnoForwardInfo, ahash::RandomState>>,
}
impl SeqnoCache {
    /// Create a new [`SeqnoCache`].
    ///
    /// Also spawns the background task which periodically removes stale entries.
    pub fn new() -> Self {
        // NOTE(review): capacity is logged as a hard-coded 0 — presumably a leftover from a
        // sized constructor; confirm this field is still meaningful.
        trace!(capacity = 0, "Creating new seqno cache");
        let cache = Arc::new(DashMap::with_hasher_and_shard_amount(
            ahash::RandomState::new(),
            // This number has been chosen completely at random
            1024,
        ));
        let sc = Self { cache };
        // Spawn background cleanup task.
        tokio::spawn(sc.clone().sweep_entries());

        sc
    }

    /// Record a forwarded seqno request to a given target. Also keep track of the origin of the
    /// request. If the local node generated the request, source must be [`None`]
    pub fn forward(&self, request: SeqnoRequestCacheKey, target: Peer, source: Option<Peer>) {
        // Entry API: creates a default (empty, timestamped now) info record on first sight.
        let mut info = self.cache.entry(request).or_default();
        info.last_sent = Instant::now();
        if !info.targets.contains(&target) {
            info.targets.push(target);
        } else {
            debug!(
                seqno_request = %request,
                "Already sent seqno request to target {}",
                target.connection_identifier()
            );
        }
        if let Some(source) = source {
            if !info.sources.contains(&source) {
                info.sources.push(source);
            } else {
                debug!(seqno_request = %request, "Peer {} is requesting the same seqno again", source.connection_identifier());
            }
        }
    }

    /// Get a list of all peers which we've already sent the given seqno request to, as well as
    /// when we've last sent a request.
    pub fn info(&self, request: &SeqnoRequestCacheKey) -> Option<(Instant, Vec<Peer>)> {
        self.cache
            .get(request)
            .map(|info| (info.last_sent, info.targets.clone()))
    }

    /// Removes forwarding info from the seqno cache. If forwarding info is available, the source
    /// peers (peers which requested us to forward this request) are returned.
    // TODO: cleanup if needed
    #[allow(dead_code)]
    pub fn remove(&self, request: &SeqnoRequestCacheKey) -> Option<Vec<Peer>> {
        self.cache.remove(request).map(|(_, info)| info.sources)
    }

    /// Get forwarding info from the seqno cache. If forwarding info is available, the source
    /// peers (peers which requested us to forward this request) are returned.
    // TODO: cleanup if needed
    #[allow(dead_code)]
    pub fn get(&self, request: &SeqnoRequestCacheKey) -> Option<Vec<Peer>> {
        self.cache.get(request).map(|info| info.sources.clone())
    }

    /// Periodic task to clear old entries for which no reply came in.
    async fn sweep_entries(self) {
        let mut interval = tokio::time::interval(SEQNO_DEDUP_TTL);
        interval.set_missed_tick_behavior(MissedTickBehavior::Skip);
        loop {
            interval.tick().await;
            debug!("Cleaning up expired seqno requests from seqno cache");
            let prev_entries = self.cache.len();
            let prev_cap = self.cache.capacity();
            self.cache
                .retain(|_, info| info.first_sent.elapsed() <= SEQNO_DEDUP_TTL);
            self.cache.shrink_to_fit();
            // NOTE(review): the subtractions assume len/capacity did not grow between the
            // snapshots above and here; concurrent inserts could in theory make these
            // underflow — confirm this is acceptable for a debug-only statistic.
            debug!(
                cleaned_entries = prev_entries - self.cache.len(),
                removed_capacity = prev_cap - self.cache.capacity(),
                "Cleaned up stale seqno request cache entries"
            );
        }
    }
}
impl Default for SeqnoCache {
fn default() -> Self {
Self::new()
}
}
impl Default for SeqnoForwardInfo {
fn default() -> Self {
Self {
sources: vec![],
targets: vec![],
first_sent: Instant::now(),
last_sent: Instant::now(),
}
}
}
impl std::fmt::Debug for SeqnoRequestCacheKey {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Render every field through its Display impl for compact output.
        let router_id = self.router_id.to_string();
        let subnet = self.subnet.to_string();
        let seqno = self.seqno.to_string();
        f.debug_struct("SeqnoRequestCacheKey")
            .field("router_id", &router_id)
            .field("subnet", &subnet)
            .field("seqno", &seqno)
            .finish()
    }
}

View File

@@ -0,0 +1,153 @@
//! Dedicated logic for
//! [sequence numbers](https://datatracker.ietf.org/doc/html/rfc8966#name-solving-starvation-sequenci).
use core::fmt;
use core::ops::{Add, AddAssign};
/// This value is compared against when deciding if a `SeqNo` is larger or smaller, [as defined in
/// the babel rfc](https://datatracker.ietf.org/doc/html/rfc8966#section-3.2.1).
// NOTE(review): "TRESHOLD" is a misspelling of "THRESHOLD"; the name is kept as-is since it is
// referenced elsewhere in this file.
const SEQNO_COMPARE_TRESHOLD: u16 = 32_768;
/// A sequence number on a route.
// Newtype over `u16` with wrapping, windowed comparison semantics; see `lt`/`gt` in the impl.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SeqNo(u16);
impl SeqNo {
    /// Create a new `SeqNo` with the default value.
    pub fn new() -> Self {
        Self::default()
    }

    /// "Smaller than" comparison as defined in [the babel
    /// rfc](https://datatracker.ietf.org/doc/html/rfc8966#section-3.2.1).
    /// The [`PartialOrd`](std::cmp::PartialOrd) trait is deliberately not implemented: its
    /// contract requires transitivity, which does not hold for this wrapping comparison.
    ///
    /// Quirk: values exactly 32_768 apart compare as "not smaller" in both directions, which is
    /// counterintuitive to the usual expectation that a < b implies !(b < a).
    pub fn lt(&self, other: &Self) -> bool {
        self.0 != other.0 && other.0.wrapping_sub(self.0) < SEQNO_COMPARE_TRESHOLD
    }

    /// "Greater than" comparison as defined in [the babel
    /// rfc](https://datatracker.ietf.org/doc/html/rfc8966#section-3.2.1).
    /// The [`PartialOrd`](std::cmp::PartialOrd) trait is deliberately not implemented: its
    /// contract requires transitivity, which does not hold for this wrapping comparison.
    ///
    /// Quirk: values exactly 32_768 apart compare as "not greater" in both directions, which is
    /// counterintuitive to the usual expectation that a > b implies !(b > a).
    pub fn gt(&self, other: &Self) -> bool {
        self.0 != other.0 && other.0.wrapping_sub(self.0) > SEQNO_COMPARE_TRESHOLD
    }
}
impl fmt::Display for SeqNo {
    /// Prints the raw `u16` value.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}
impl From<u16> for SeqNo {
    /// Wrap a raw `u16` as a sequence number.
    fn from(value: u16) -> Self {
        Self(value)
    }
}

impl From<SeqNo> for u16 {
    /// Extract the raw `u16` from a sequence number.
    fn from(value: SeqNo) -> Self {
        value.0
    }
}

impl Add<u16> for SeqNo {
    type Output = Self;

    /// Wrapping addition, so incrementing past `u16::MAX` rolls over to 0.
    fn add(self, rhs: u16) -> Self::Output {
        Self(self.0.wrapping_add(rhs))
    }
}

impl AddAssign<u16> for SeqNo {
    /// In-place wrapping addition.
    fn add_assign(&mut self, rhs: u16) {
        self.0 = self.0.wrapping_add(rhs);
    }
}
#[cfg(test)]
mod tests {
    use super::SeqNo;

    // Equal values are equal regardless of magnitude.
    #[test]
    fn cmp_eq_seqno() {
        let s1 = SeqNo::from(1);
        let s2 = SeqNo::from(1);
        assert_eq!(s1, s2);
        let s1 = SeqNo::from(10_000);
        let s2 = SeqNo::from(10_000);
        assert_eq!(s1, s2);
    }

    // Differences below the 32_768 window compare in the natural direction.
    #[test]
    fn cmp_small_seqno_increase() {
        let s1 = SeqNo::from(1);
        let s2 = SeqNo::from(2);
        assert!(s1.lt(&s2));
        assert!(!s2.lt(&s1));
        assert!(s2.gt(&s1));
        assert!(!s1.gt(&s2));
        let s1 = SeqNo::from(3);
        let s2 = SeqNo::from(30_000);
        assert!(s1.lt(&s2));
        assert!(!s2.lt(&s1));
        assert!(s2.gt(&s1));
        assert!(!s1.gt(&s2));
    }

    // At and beyond the comparison window the ordering flips (wrapping semantics).
    #[test]
    fn cmp_big_seqno_increase() {
        let s1 = SeqNo::from(0);
        let s2 = SeqNo::from(32_767);
        assert!(s1.lt(&s2));
        assert!(!s2.lt(&s1));
        assert!(s2.gt(&s1));
        assert!(!s1.gt(&s2));
        // Test equality quirk at cutoff point.
        let s1 = SeqNo::from(0);
        let s2 = SeqNo::from(32_768);
        assert!(!s1.lt(&s2));
        assert!(!s2.lt(&s1));
        assert!(!s2.gt(&s1));
        assert!(!s1.gt(&s2));
        let s1 = SeqNo::from(0);
        let s2 = SeqNo::from(32_769);
        assert!(!s1.lt(&s2));
        assert!(s2.lt(&s1));
        assert!(!s2.gt(&s1));
        assert!(s1.gt(&s2));
        let s1 = SeqNo::from(6);
        let s2 = SeqNo::from(60_000);
        assert!(!s1.lt(&s2));
        assert!(s2.lt(&s1));
        assert!(!s2.gt(&s1));
        assert!(s1.gt(&s2));
    }
}

View File

@@ -0,0 +1,489 @@
use core::fmt;
use std::{collections::HashMap, time::Duration};
use tokio::{sync::mpsc, task::JoinHandle};
use tracing::error;
use crate::{
babel, metric::Metric, router_id::RouterId, routing_table::RouteEntry, sequence_number::SeqNo,
subnet::Subnet,
};
/// Duration (30 minutes) after which a source entry is deleted if it is not updated.
const SOURCE_HOLD_DURATION: Duration = Duration::from_secs(60 * 30);
/// Key of a source table entry: an advertised subnet together with the advertising router.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Copy)]
pub struct SourceKey {
    // Subnet being advertised.
    subnet: Subnet,
    // Router advertising the subnet.
    router_id: RouterId,
}
/// The (metric, seqno) pair retained for a source, used by the feasibility checks.
#[derive(Debug, Clone, Copy)]
pub struct FeasibilityDistance {
    metric: Metric,
    seqno: SeqNo,
}
/// Table mapping each [`SourceKey`] to its [`FeasibilityDistance`].
#[derive(Debug)]
pub struct SourceTable {
    // Each entry is paired with the JoinHandle of its expiration timer task, so the timer can
    // be aborted when the entry is replaced or removed.
    table: HashMap<SourceKey, (JoinHandle<()>, FeasibilityDistance)>,
}
impl FeasibilityDistance {
pub fn new(metric: Metric, seqno: SeqNo) -> Self {
FeasibilityDistance { metric, seqno }
}
/// Returns the metric for this `FeasibilityDistance`.
pub const fn metric(&self) -> Metric {
self.metric
}
/// Returns the sequence number for this `FeasibilityDistance`.
pub const fn seqno(&self) -> SeqNo {
self.seqno
}
}
impl SourceKey {
/// Create a new `SourceKey`.
pub const fn new(subnet: Subnet, router_id: RouterId) -> Self {
Self { subnet, router_id }
}
/// Returns the [`RouterId`] for this `SourceKey`.
pub const fn router_id(&self) -> RouterId {
self.router_id
}
/// Returns the [`Subnet`] for this `SourceKey`.
pub const fn subnet(&self) -> Subnet {
self.subnet
}
/// Updates the [`RouterId`] of this `SourceKey`
pub fn set_router_id(&mut self, router_id: RouterId) {
self.router_id = router_id
}
}
impl SourceTable {
    /// Create a new, empty `SourceTable`.
    pub fn new() -> Self {
        Self {
            table: HashMap::new(),
        }
    }

    /// Insert a [`FeasibilityDistance`] for the given [`SourceKey`].
    ///
    /// A garbage collection timer is started; when it fires, the key is pushed on `sink` so the
    /// router can expire the entry. Replacing an existing entry aborts that entry's timer.
    pub fn insert(
        &mut self,
        key: SourceKey,
        feas_dist: FeasibilityDistance,
        sink: mpsc::Sender<SourceKey>,
    ) {
        let expiration_handle = tokio::spawn(async move {
            tokio::time::sleep(SOURCE_HOLD_DURATION).await;
            if let Err(e) = sink.send(key).await {
                error!("Failed to notify router of expired source key {e}");
            }
        });
        // Abort the old task if present.
        if let Some((old_timeout, _)) = self.table.insert(key, (expiration_handle, feas_dist)) {
            old_timeout.abort();
        }
    }

    /// Remove an entry from the source table.
    // Also cancels the entry's pending expiration timer.
    pub fn remove(&mut self, key: &SourceKey) {
        if let Some((old_timeout, _)) = self.table.remove(key) {
            old_timeout.abort();
        };
    }

    /// Resets the garbage collection timer for a given source key.
    ///
    /// Does nothing if the source key is not present.
    pub fn reset_timer(&mut self, key: SourceKey, sink: mpsc::Sender<SourceKey>) {
        self.table
            .entry(key)
            .and_modify(|(old_expiration_handle, _)| {
                // First cancel the existing task
                old_expiration_handle.abort();
                // Then set the new one
                *old_expiration_handle = tokio::spawn(async move {
                    tokio::time::sleep(SOURCE_HOLD_DURATION).await;
                    if let Err(e) = sink.send(key).await {
                        error!("Failed to notify router of expired source key {e}");
                    }
                });
            });
    }

    /// Get the [`FeasibilityDistance`] currently associated with the [`SourceKey`].
    pub fn get(&self, key: &SourceKey) -> Option<&FeasibilityDistance> {
        self.table.get(key).map(|(_, v)| v)
    }

    /// Indicates if an update is feasible in the context of the current `SourceTable`.
    pub fn is_update_feasible(&self, update: &babel::Update) -> bool {
        // Before an update is accepted it should be checked against the feasibility condition.
        // If an entry in the source table with the same source key exists, we perform the
        // feasibility check; if no entry exists yet, the update is accepted as there is no
        // better alternative available (yet).
        let source_key = SourceKey::new(update.subnet(), update.router_id());
        match self.get(&source_key) {
            Some(entry) => {
                // Feasible when the seqno is strictly newer, or equal with a strictly better
                // metric, or when the update is a retraction (infinite metric).
                (update.seqno().gt(&entry.seqno()))
                    || (update.seqno() == entry.seqno() && update.metric() < entry.metric())
                    || update.metric().is_infinite()
            }
            None => true,
        }
    }

    /// Indicates if a [`RouteEntry`] is feasible according to the `SourceTable`.
    pub fn route_feasible(&self, route: &RouteEntry) -> bool {
        // Same condition as for updates: newer seqno, better metric at equal seqno, or a
        // retracted (infinite metric) route.
        match self.get(&route.source()) {
            Some(fd) => {
                (route.seqno().gt(&fd.seqno))
                    || (route.seqno() == fd.seqno && route.metric() < fd.metric)
                    || route.metric().is_infinite()
            }
            None => true,
        }
    }
}
impl fmt::Display for SourceKey {
    /// Renders as "<subnet> advertised by <router id>".
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{} advertised by {}", self.subnet, self.router_id)
    }
}
#[cfg(test)]
mod tests {
    use tokio::sync::mpsc;

    use crate::{
        babel,
        crypto::SecretKey,
        metric::Metric,
        peer::Peer,
        router_id::RouterId,
        routing_table::RouteEntry,
        sequence_number::SeqNo,
        source_table::{FeasibilityDistance, SourceKey, SourceTable},
        subnet::Subnet,
    };
    use std::{
        net::Ipv6Addr,
        sync::{atomic::AtomicU64, Arc},
        time::Duration,
    };

    // Every test previously duplicated ~20 lines of identical setup; the shared fixtures live
    // in the helpers below.

    /// Build a `SourceTable` holding a single entry (metric 10, seqno 1) for a fixed test
    /// subnet, returning the table and the [`SourceKey`] of the inserted entry.
    fn table_with_entry() -> (SourceTable, SourceKey) {
        let (sink, _) = tokio::sync::mpsc::channel(1);
        let sk = SecretKey::new();
        let pk = (&sk).into();
        let sn = Subnet::new(Ipv6Addr::new(0x400, 0, 0, 0, 0, 0, 0, 1).into(), 64)
            .expect("Valid subnet in test case");
        let source_key = SourceKey::new(sn, RouterId::new(pk));
        let mut st = SourceTable::new();
        st.insert(
            source_key,
            FeasibilityDistance::new(Metric::new(10), SeqNo::from(1)),
            sink,
        );
        (st, source_key)
    }

    /// Build an update against the given source with the provided seqno and metric.
    fn update(source_key: SourceKey, seqno: u16, metric: Metric) -> babel::Update {
        babel::Update::new(
            Duration::from_secs(60),
            SeqNo::from(seqno),
            metric,
            source_key.subnet(),
            source_key.router_id(),
        )
    }

    /// Construct a dummy `Peer` which is not connected to anything meaningful.
    fn dummy_peer() -> Peer {
        let (router_data_tx, _router_data_rx) = mpsc::channel(1);
        let (router_control_tx, _router_control_rx) = mpsc::unbounded_channel();
        let (dead_peer_sink, _dead_peer_stream) = mpsc::channel(1);
        let (con1, _con2) = tokio::io::duplex(1500);
        Peer::new(
            router_data_tx,
            router_control_tx,
            con1,
            dead_peer_sink,
            Arc::new(AtomicU64::new(0)),
            Arc::new(AtomicU64::new(0)),
        )
        .expect("Can create a dummy peer")
    }

    /// Build a selected route entry for the given source with seqno 1 and the given metric.
    fn route(source_key: SourceKey, metric: Metric) -> RouteEntry {
        RouteEntry::new(
            source_key,
            dummy_peer(),
            metric,
            SeqNo::from(1),
            true,
            tokio::time::Instant::now() + Duration::from_secs(60),
        )
    }

    /// A retraction is always considered to be feasible.
    #[tokio::test]
    async fn retraction_update_is_feasible() {
        let (st, key) = table_with_entry();
        assert!(st.is_update_feasible(&update(key, 0, Metric::infinite())));
    }

    /// An update with a smaller metric but with the same seqno is feasible.
    #[tokio::test]
    async fn smaller_metric_update_is_feasible() {
        let (st, key) = table_with_entry();
        assert!(st.is_update_feasible(&update(key, 1, Metric::from(9))));
    }

    /// An update with the same metric and seqno is not feasible.
    #[tokio::test]
    async fn equal_metric_update_is_unfeasible() {
        let (st, key) = table_with_entry();
        assert!(!st.is_update_feasible(&update(key, 1, Metric::from(10))));
    }

    /// An update with a larger metric and the same seqno is not feasible.
    #[tokio::test]
    async fn larger_metric_update_is_unfeasible() {
        let (st, key) = table_with_entry();
        assert!(!st.is_update_feasible(&update(key, 1, Metric::from(11))));
    }

    /// An update with a lower seqno is not feasible.
    #[tokio::test]
    async fn lower_seqno_update_is_unfeasible() {
        let (st, key) = table_with_entry();
        assert!(!st.is_update_feasible(&update(key, 0, Metric::from(1))));
    }

    /// An update with a higher seqno is feasible.
    #[tokio::test]
    async fn higher_seqno_update_is_feasible() {
        let (st, key) = table_with_entry();
        assert!(st.is_update_feasible(&update(key, 2, Metric::from(200))));
    }

    /// A route with a smaller metric but with the same seqno is feasible.
    #[tokio::test]
    async fn smaller_metric_route_is_feasible() {
        let (st, key) = table_with_entry();
        assert!(st.route_feasible(&route(key, Metric::new(9))));
    }

    /// If a route has the same metric as the source table it is not feasible.
    #[tokio::test]
    async fn equal_metric_route_is_unfeasible() {
        let (st, key) = table_with_entry();
        assert!(!st.route_feasible(&route(key, Metric::new(10))));
    }

    /// If a route has a higher metric as the source table it is not feasible.
    #[tokio::test]
    async fn higher_metric_route_is_unfeasible() {
        let (st, key) = table_with_entry();
        assert!(!st.route_feasible(&route(key, Metric::new(11))));
    }
}

277
mycelium/src/subnet.rs Normal file
View File

@@ -0,0 +1,277 @@
//! A dedicated subnet module.
//!
//! The standard library only exposes [`IpAddr`], and types related to
//! specific IPv4 and IPv6 addresses. It does not however, expose dedicated types to represent
//! appropriate subnets.
//!
//! This code is not meant to fully support subnets, but rather only the subset as needed by the
//! main application code. As such, this implementation is optimized for the specific use case, and
//! might not be optimal for other uses.
use core::fmt;
use std::{
hash::Hash,
net::{IpAddr, Ipv6Addr},
str::FromStr,
};
use ipnet::IpNet;
/// Representation of a subnet. A subnet can be either IPv4 or IPv6.
// NOTE(review): `Eq` is derived without `PartialEq` — presumably `PartialEq` (and `Hash`) are
// implemented manually further down this file; confirm.
#[derive(Debug, Clone, Copy, Eq, PartialOrd, Ord)]
pub struct Subnet {
    // The wrapped `ipnet` network which backs all operations.
    inner: IpNet,
}
/// An error returned when creating a new [`Subnet`] with an invalid prefix length.
///
/// For IPv4, the max prefix length is 32, and for IPv6 it is 128.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PrefixLenError;
impl Subnet {
/// Create a new `Subnet` from the given [`IpAddr`] and prefix length.
pub fn new(addr: IpAddr, prefix_len: u8) -> Result<Subnet, PrefixLenError> {
Ok(Self {
inner: IpNet::new(addr, prefix_len).map_err(|_| PrefixLenError)?,
})
}
/// Returns the size of the prefix in bits.
///
/// At most 32 for an IPv4 subnet and 128 for an IPv6 subnet.
pub fn prefix_len(&self) -> u8 {
    self.inner.prefix_len()
}
/// Returns the address in this subnet.
///
/// The returned address is a full IP address, used to construct this `Subnet`.
///
/// # Examples
///
/// ```
/// use mycelium::subnet::Subnet;
/// use std::net::Ipv6Addr;
///
/// let address = Ipv6Addr::new(12,34,56,78,90,0xab,0xcd,0xef).into();
/// let subnet = Subnet::new(address, 64).unwrap();
///
/// assert_eq!(subnet.address(), address);
/// ```
pub fn address(&self) -> IpAddr {
    self.inner.addr()
}
/// Checks if this `Subnet` contains the provided `Subnet`, i.e. all addresses of the provided
/// `Subnet` are also part of this `Subnet`
///
/// # Examples
///
/// ```
/// use mycelium::subnet::Subnet;
/// use std::net::Ipv4Addr;
///
/// let global = Subnet::new(Ipv4Addr::new(0,0,0,0).into(), 0).expect("Defined a valid subnet");
/// let local = Subnet::new(Ipv4Addr::new(10,0,0,0).into(), 8).expect("Defined a valid subnet");
///
/// assert!(global.contains_subnet(&local));
/// assert!(!local.contains_subnet(&global));
/// ```
pub fn contains_subnet(&self, other: &Self) -> bool {
self.inner.contains(&other.inner)
}
/// Checks if this `Subnet` contains the provided [`IpAddr`].
///
/// # Examples
///
/// ```
/// use mycelium::subnet::Subnet;
/// use std::net::{Ipv4Addr,Ipv6Addr};
///
/// let ip_1 = Ipv6Addr::new(12,34,56,78,90,0xab,0xcd,0xef).into();
/// let ip_2 = Ipv6Addr::new(90,0xab,0xcd,0xef,12,34,56,78).into();
/// let ip_3 = Ipv4Addr::new(10,1,2,3).into();
/// let subnet = Subnet::new(Ipv6Addr::new(12,34,5,6,7,8,9,0).into(), 32).unwrap();
///
/// assert!(subnet.contains_ip(ip_1));
/// assert!(!subnet.contains_ip(ip_2));
/// assert!(!subnet.contains_ip(ip_3));
/// ```
pub fn contains_ip(&self, ip: IpAddr) -> bool {
self.inner.contains(&ip)
}
/// Returns the network part of the `Subnet`. All non prefix bits are set to 0.
///
/// # Examples
///
/// ```
/// use mycelium::subnet::Subnet;
/// use std::net::{IpAddr, Ipv4Addr,Ipv6Addr};
///
/// let subnet_1 = Subnet::new(Ipv6Addr::new(12,34,56,78,90,0xab,0xcd,0xef).into(),
/// 32).unwrap();
/// let subnet_2 = Subnet::new(Ipv4Addr::new(10,1,2,3).into(), 8).unwrap();
///
/// assert_eq!(subnet_1.network(), IpAddr::V6(Ipv6Addr::new(12,34,0,0,0,0,0,0)));
/// assert_eq!(subnet_2.network(), IpAddr::V4(Ipv4Addr::new(10,0,0,0)));
/// ```
pub fn network(&self) -> IpAddr {
self.inner.network()
}
/// Returns the braodcast address for the subnet.
///
/// # Examples
///
/// ```
/// use mycelium::subnet::Subnet;
/// use std::net::{IpAddr, Ipv4Addr,Ipv6Addr};
///
/// let subnet_1 = Subnet::new(Ipv6Addr::new(12,34,56,78,90,0xab,0xcd,0xef).into(),
/// 32).unwrap();
/// let subnet_2 = Subnet::new(Ipv4Addr::new(10,1,2,3).into(), 8).unwrap();
///
/// assert_eq!(subnet_1.broadcast_addr(),
/// IpAddr::V6(Ipv6Addr::new(12,34,0xffff,0xffff,0xffff,0xffff,0xffff,0xffff)));
/// assert_eq!(subnet_2.broadcast_addr(), IpAddr::V4(Ipv4Addr::new(10,255,255,255)));
/// ```
pub fn broadcast_addr(&self) -> IpAddr {
self.inner.broadcast()
}
/// Returns the netmask of the subnet as an [`IpAddr`].
pub fn mask(&self) -> IpAddr {
self.inner.netmask()
}
}
impl From<Ipv6Addr> for Subnet {
fn from(value: Ipv6Addr) -> Self {
Self::new(value.into(), 128).expect("128 is a valid subnet size for an IPv6 address; qed")
}
}
#[derive(Debug, Clone)]
/// An error indicating a malformed subnet
pub struct SubnetParseError {
    // Private marker field so the error can only be constructed inside this module.
    _private: (),
}

impl SubnetParseError {
    /// Construct a new `SubnetParseError`.
    fn new() -> Self {
        SubnetParseError { _private: () }
    }
}

impl core::fmt::Display for SubnetParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `<str as Display>::fmt` forwards to `f.pad`, so formatter flags are honored.
        "malformed subnet".fmt(f)
    }
}

impl std::error::Error for SubnetParseError {}
impl FromStr for Subnet {
    type Err = SubnetParseError;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // First attempt: explicit CIDR notation ("addr/prefix").
        if let Ok(net) = s.parse::<ipnet::IpNet>() {
            return Ok(
                Subnet::new(net.addr(), net.prefix_len()).expect("Parsed subnet size is valid")
            );
        }
        // Second attempt: a bare IP address, interpreted as a host subnet (/32 or /128).
        match s.parse::<std::net::IpAddr>() {
            Ok(ip) => {
                let prefix_len = if matches!(ip, std::net::IpAddr::V4(_)) {
                    32
                } else {
                    128
                };
                Ok(Subnet::new(ip, prefix_len).expect("Static subnet sizes are valid"))
            }
            Err(_) => Err(SubnetParseError::new()),
        }
    }
}
impl fmt::Display for Subnet {
    /// Formats the subnet in CIDR notation, as produced by the wrapped [`IpNet`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!("{}", self.inner))
    }
}
impl PartialEq for Subnet {
    fn eq(&self, other: &Self) -> bool {
        // Subnets of different sizes are never equal, which makes the prefix length a cheap
        // early-exit check; otherwise compare the canonical network addresses so host bits
        // are ignored.
        self.prefix_len() == other.prefix_len() && self.network() == other.network()
    }
}
impl Hash for Subnet {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        // Hash the canonical form (prefix length, then network address with the host bits
        // zeroed) so values which are equal per `PartialEq` also hash identically.
        // `u8::hash` forwards to `Hasher::write_u8`, matching the original byte stream.
        self.prefix_len().hash(state);
        self.network().hash(state)
    }
}
impl fmt::Display for PrefixLenError {
    // `write_str` emits the message as-is; formatter width/fill flags are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Invalid prefix length for this address")
    }
}
impl std::error::Error for PrefixLenError {}
#[cfg(test)]
mod tests {
    use std::net::{Ipv4Addr, Ipv6Addr};
    use super::Subnet;
    // Equality must depend only on (prefix length, network address): host bits are ignored,
    // and a different prefix length makes otherwise identical addresses unequal.
    #[test]
    fn test_subnet_equality() {
        // IPv6 cases: same /64 network with different host bits are equal.
        let subnet_1 =
            Subnet::new(Ipv6Addr::new(12, 23, 34, 45, 56, 67, 78, 89).into(), 64).unwrap();
        let subnet_2 =
            Subnet::new(Ipv6Addr::new(12, 23, 34, 45, 67, 78, 89, 90).into(), 64).unwrap();
        let subnet_3 =
            Subnet::new(Ipv6Addr::new(12, 23, 34, 40, 67, 78, 89, 90).into(), 64).unwrap();
        // All-zero and all-one host bits still compare equal to subnet_1.
        let subnet_4 = Subnet::new(Ipv6Addr::new(12, 23, 34, 45, 0, 0, 0, 0).into(), 64).unwrap();
        let subnet_5 = Subnet::new(
            Ipv6Addr::new(12, 23, 34, 45, 0xffff, 0xffff, 0xffff, 0xffff).into(),
            64,
        )
        .unwrap();
        // Same address, different prefix length: not equal.
        let subnet_6 =
            Subnet::new(Ipv6Addr::new(12, 23, 34, 45, 56, 67, 78, 89).into(), 63).unwrap();
        assert_eq!(subnet_1, subnet_2);
        assert_ne!(subnet_1, subnet_3);
        assert_eq!(subnet_1, subnet_4);
        assert_eq!(subnet_1, subnet_5);
        assert_ne!(subnet_1, subnet_6);
        // Same scenarios for IPv4 /24 subnets.
        let subnet_1 = Subnet::new(Ipv4Addr::new(10, 1, 2, 3).into(), 24).unwrap();
        let subnet_2 = Subnet::new(Ipv4Addr::new(10, 1, 2, 102).into(), 24).unwrap();
        let subnet_3 = Subnet::new(Ipv4Addr::new(10, 1, 4, 3).into(), 24).unwrap();
        let subnet_4 = Subnet::new(Ipv4Addr::new(10, 1, 2, 0).into(), 24).unwrap();
        let subnet_5 = Subnet::new(Ipv4Addr::new(10, 1, 2, 255).into(), 24).unwrap();
        let subnet_6 = Subnet::new(Ipv4Addr::new(10, 1, 2, 3).into(), 16).unwrap();
        assert_eq!(subnet_1, subnet_2);
        assert_ne!(subnet_1, subnet_3);
        assert_eq!(subnet_1, subnet_4);
        assert_eq!(subnet_1, subnet_5);
        assert_ne!(subnet_1, subnet_6);
    }
}

29
mycelium/src/task.rs Normal file
View File

@@ -0,0 +1,29 @@
//! This module provides some task abstractions which add custom logic to the default behavior.
/// A handle to a task, which is only used to abort the task. In case this handle is dropped, the
/// task is cancelled automatically.
pub struct AbortHandle(tokio::task::AbortHandle);
impl AbortHandle {
    /// Abort the task this `AbortHandle` is referencing. It is safe to call this method multiple
    /// times, but only the first call is actually useful. It is possible for the task to still
    /// finish successfully, even after abort is called.
    #[inline]
    pub fn abort(&self) {
        self.0.abort()
    }
}
impl Drop for AbortHandle {
    // Aborting on drop ties the lifetime of the task to the lifetime of this handle.
    #[inline]
    fn drop(&mut self) {
        self.0.abort()
    }
}
impl From<tokio::task::AbortHandle> for AbortHandle {
    #[inline]
    fn from(value: tokio::task::AbortHandle) -> Self {
        Self(value)
    }
}

54
mycelium/src/tun.rs Normal file
View File

@@ -0,0 +1,54 @@
//! The tun module implements a platform independent Tun interface.
#[cfg(any(
target_os = "linux",
all(target_os = "macos", not(feature = "mactunfd")),
target_os = "windows"
))]
use crate::subnet::Subnet;
#[cfg(any(
    target_os = "linux",
    all(target_os = "macos", not(feature = "mactunfd")),
    target_os = "windows"
))]
/// Configuration for platforms where this process creates the TUN device itself.
pub struct TunConfig {
    /// Name of the TUN device to create.
    pub name: String,
    /// Subnet whose address is assigned to the TUN device.
    pub node_subnet: Subnet,
    /// Subnet routed over the TUN device; its prefix length / mask is used when configuring
    /// the interface address.
    pub route_subnet: Subnet,
}
#[cfg(any(
    target_os = "android",
    target_os = "ios",
    all(target_os = "macos", feature = "mactunfd"),
))]
/// Configuration for platforms where the surrounding application creates the TUN device and
/// hands over an already-open file descriptor.
pub struct TunConfig {
    /// Raw file descriptor of the pre-created TUN device.
    pub tun_fd: i32,
}
// Per-platform implementations. Each module exposes a `new` function, re-exported here so
// callers remain platform agnostic.
#[cfg(target_os = "linux")]
mod linux;
#[cfg(target_os = "linux")]
pub use linux::new;
// On macOS the name-based implementation is used unless the "mactunfd" feature requests the
// fd-based (app-store) variant.
#[cfg(all(target_os = "macos", not(feature = "mactunfd")))]
mod darwin;
#[cfg(all(target_os = "macos", not(feature = "mactunfd")))]
pub use darwin::new;
#[cfg(target_os = "windows")]
mod windows;
#[cfg(target_os = "windows")]
pub use windows::new;
#[cfg(target_os = "android")]
mod android;
#[cfg(target_os = "android")]
pub use android::new;
// iOS and fd-based macOS share the same implementation.
#[cfg(any(target_os = "ios", all(target_os = "macos", feature = "mactunfd")))]
mod ios;
#[cfg(any(target_os = "ios", all(target_os = "macos", feature = "mactunfd")))]
pub use ios::new;

108
mycelium/src/tun/android.rs Normal file
View File

@@ -0,0 +1,108 @@
//! android specific tun interface setup.
use std::io::{self};
use futures::{Sink, Stream};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
select,
sync::mpsc,
};
use tracing::{error, info};
use crate::crypto::PacketBuffer;
use crate::tun::TunConfig;
// TODO
const LINK_MTU: i32 = 1400;
/// Create a new tun interface and set required routes
///
/// Returns a stream of packets read from the TUN device and a sink to write packets to it.
///
/// # Panics
///
/// This function will panic if called outside of the context of a tokio runtime.
pub async fn new(
    tun_config: TunConfig,
) -> Result<
    (
        impl Stream<Item = io::Result<PacketBuffer>>,
        impl Sink<PacketBuffer, Error = impl std::error::Error> + Clone,
    ),
    Box<dyn std::error::Error>,
> {
    let name = "tun0";
    let mut tun = create_tun_interface(name, tun_config.tun_fd)?;
    // Bounded channel for egress (towards the TUN), unbounded channel for ingress.
    let (tun_sink, mut sink_receiver) = mpsc::channel::<PacketBuffer>(1000);
    let (tun_stream, stream_receiver) = mpsc::unbounded_channel();
    // Spawn a single task to manage the TUN interface
    tokio::spawn(async move {
        // Keep an unused buffer around between iterations so its allocation can be reused.
        let mut buf_hold = None;
        loop {
            // Same idiom as the darwin/ios/linux implementations, replacing the open-coded
            // `if let Some … else PacketBuffer::new()` of the original.
            let mut buf: PacketBuffer = buf_hold.take().unwrap_or_default();
            select! {
                data = sink_receiver.recv() => {
                    match data {
                        None => return,
                        Some(data) => {
                            if let Err(e) = tun.write(&data).await {
                                error!("Failed to send data to tun interface {e}");
                            }
                        }
                    }
                    // Save the buffer as we didn't use it
                    buf_hold = Some(buf);
                }
                read_result = tun.read(buf.buffer_mut()) => {
                    let rr = read_result.map(|n| {
                        buf.set_size(n);
                        buf
                    });
                    if tun_stream.send(rr).is_err() {
                        error!("Could not forward data to tun stream, receiver is gone");
                        break;
                    };
                }
            }
        }
        info!("Stop reading from / writing to tun interface");
    });
    Ok((
        tokio_stream::wrappers::UnboundedReceiverStream::new(stream_receiver),
        tokio_util::sync::PollSender::new(tun_sink),
    ))
}
/// Create a new TUN interface backed by the provided, already-open file descriptor.
fn create_tun_interface(
    name: &str,
    tun_fd: i32,
) -> Result<tun::AsyncDevice, Box<dyn std::error::Error>> {
    let mut cfg = tun::Configuration::default();
    cfg.name(name)
        .layer(tun::Layer::L3)
        .mtu(LINK_MTU)
        .queues(1)
        .raw_fd(tun_fd)
        .up();
    info!("create_tun_interface");
    match tun::create_as_async(&cfg) {
        Ok(tun) => Ok(tun),
        Err(err) => {
            error!("[android]failed to create tun interface: {err}");
            Err(Box::new(err))
        }
    }
}

347
mycelium/src/tun/darwin.rs Normal file
View File

@@ -0,0 +1,347 @@
//! macos specific tun interface setup.
use std::{
ffi::CString,
io::{self, IoSlice},
net::IpAddr,
os::fd::AsRawFd,
str::FromStr,
};
use futures::{Sink, Stream};
use nix::sys::socket::SockaddrIn6;
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
select,
sync::mpsc,
};
use tracing::{debug, error, info, warn};
use crate::crypto::PacketBuffer;
use crate::subnet::Subnet;
use crate::tun::TunConfig;
// TODO
const LINK_MTU: i32 = 1400;
/// The 4 byte packet header written before a packet is sent on the TUN
// TODO: figure out structure and values, but for now this seems to work.
const HEADER: [u8; 4] = [0, 0, 0, 30];
const IN6_IFF_NODAD: u32 = 0x0020; // netinet6/in6_var.h
const IN6_IFF_SECURED: u32 = 0x0400; // netinet6/in6_var.h
const ND6_INFINITE_LIFETIME: u32 = 0xFFFFFFFF; // netinet6/nd6.h
/// Wrapper for an OS-specific interface name
// Allways hold the max size of an interface. This includes the 0 byte for termination.
// repr transparent so this can be used with libc calls.
#[repr(transparent)]
#[derive(Clone, Copy)]
pub struct IfaceName([libc::c_char; libc::IFNAMSIZ as _]);
/// Wrapped interface handle.
#[derive(Clone, Copy)]
struct Iface {
    /// Name of the interface
    iface_name: IfaceName,
}
/// Struct to add IPv6 route to interface
// repr(C) so the layout matches what the SIOCAIFADDR_IN6 ioctl expects; do not reorder fields.
#[repr(C)]
pub struct IfaliasReq {
    /// Name of the interface the address is added to.
    ifname: IfaceName,
    /// Address to assign.
    addr: SockaddrIn6,
    /// Destination address; zeroed when unused.
    dst_addr: SockaddrIn6,
    /// Netmask for the assigned address.
    mask: SockaddrIn6,
    /// IN6_IFF_* flags.
    flags: u32,
    lifetime: AddressLifetime,
}
// repr(C) for the same ioctl-layout reason as `IfaliasReq`.
#[repr(C)]
pub struct AddressLifetime {
    /// Not used for userspace -> kernel space
    expire: libc::time_t,
    /// Not used for userspace -> kernel space
    preferred: libc::time_t,
    /// Valid lifetime of the address, in seconds.
    vltime: u32,
    /// Preferred lifetime of the address, in seconds.
    pltime: u32,
}
/// Create a new tun interface and set required routes
///
/// # Panics
///
/// This function will panic if called outside of the context of a tokio runtime.
pub async fn new(
    tun_config: TunConfig,
) -> Result<
    (
        impl Stream<Item = io::Result<PacketBuffer>>,
        impl Sink<PacketBuffer, Error = impl std::error::Error> + Clone,
    ),
    Box<dyn std::error::Error>,
> {
    // Resolve the requested name to a free `utunX` name, then create and configure the device.
    let tun_name = find_available_utun_name(&tun_config.name)?;
    let mut tun = match create_tun_interface(&tun_name) {
        Ok(tun) => tun,
        Err(e) => {
            error!(tun_name=%tun_name, err=%e, "Could not create TUN device. Make sure the name is not yet in use, and you have sufficient privileges to create a network device");
            return Err(e);
        }
    };
    let iface = Iface::by_name(&tun_name)?;
    iface.add_address(tun_config.node_subnet, tun_config.route_subnet)?;
    // Bounded channel for egress (towards the TUN), unbounded channel for ingress.
    let (tun_sink, mut sink_receiver) = mpsc::channel::<PacketBuffer>(1000);
    let (tun_stream, stream_receiver) = mpsc::unbounded_channel();
    // Spawn a single task to manage the TUN interface
    tokio::spawn(async move {
        // Keep an unused buffer around between iterations so its allocation can be reused.
        let mut buf_hold = None;
        loop {
            let mut buf: PacketBuffer = buf_hold.take().unwrap_or_default();
            select! {
                data = sink_receiver.recv() => {
                    match data {
                        None => return,
                        Some(data) => {
                            // We need to append a 4 byte header here
                            // NOTE(review): `write_vectored` is not guaranteed to write every
                            // byte; this assumes the utun device accepts the packet atomically
                            // — confirm.
                            if let Err(e) = tun.write_vectored(&[IoSlice::new(&HEADER), IoSlice::new(&data)]).await {
                                error!("Failed to send data to tun interface {e}");
                            }
                        }
                    }
                    // Save the buffer as we didn't use it
                    buf_hold = Some(buf);
                }
                read_result = tun.read(buf.buffer_mut()) => {
                    let rr = read_result.map(|n| {
                        buf.set_size(n);
                        // Trim header
                        // NOTE(review): assumes every successful read returns at least the
                        // 4 byte utun header; `n - 4` underflows for shorter reads — confirm.
                        buf.buffer_mut().copy_within(4.., 0);
                        buf.set_size(n-4);
                        buf
                    });
                    if tun_stream.send(rr).is_err() {
                        error!("Could not forward data to tun stream, receiver is gone");
                        break;
                    };
                }
            }
        }
        info!("Stop reading from / writing to tun interface");
    });
    Ok((
        tokio_stream::wrappers::UnboundedReceiverStream::new(stream_receiver),
        tokio_util::sync::PollSender::new(tun_sink),
    ))
}
/// Checks if a name is valid for a utun interface
///
/// Rules:
/// - must start with "utun"
/// - followed by one or more digits (and nothing else)
/// - 15 chars total at most
fn validate_utun_name(input: &str) -> bool {
    if input.len() > 15 {
        return false;
    }
    match input.strip_prefix("utun") {
        // The suffix must be non-empty and consist solely of ASCII digits. The original
        // `parse::<u64>()` check also accepted a leading '+' sign (e.g. "utun+1"), which is
        // not a valid interface name.
        Some(suffix) => !suffix.is_empty() && suffix.bytes().all(|b| b.is_ascii_digit()),
        None => false,
    }
}
/// Validates the user-supplied TUN interface name
///
/// - If the name is valid and not in use, it will be the TUN name
/// - If the name is valid but already in use, an error will be thrown
/// - If the name is not valid, we try to find the first freely available TUN name
fn find_available_utun_name(preferred_name: &str) -> Result<String, io::Error> {
    // Get the list of existing utun interfaces.
    let interfaces = netdev::get_interfaces();
    let utun_interfaces: Vec<_> = interfaces
        .iter()
        .filter_map(|iface| {
            if iface.name.starts_with("utun") {
                Some(iface.name.as_str())
            } else {
                None
            }
        })
        .collect();
    // Check if the preferred name is valid and not in use.
    if validate_utun_name(preferred_name) && !utun_interfaces.contains(&preferred_name) {
        return Ok(preferred_name.to_string());
    }
    // If the preferred name is invalid or already in use, find the first available utun name.
    if !validate_utun_name(preferred_name) {
        warn!(tun_name=%preferred_name, "Invalid TUN name. Looking for the first available TUN name");
    } else {
        warn!(tun_name=%preferred_name, "TUN name already in use. Looking for the next available TUN name.");
    }
    // Extract and sort the utun numbers. `strip_prefix` avoids the raw `iface[4..]` slice
    // indexing of the original.
    let mut utun_numbers = utun_interfaces
        .iter()
        .filter_map(|iface| iface.strip_prefix("utun")?.parse::<usize>().ok())
        .collect::<Vec<_>>();
    utun_numbers.sort_unstable();
    // Find the first available utun index, i.e. the first gap in the sorted number list.
    let mut first_free_index = 0;
    for (i, &num) in utun_numbers.iter().enumerate() {
        if num != i {
            first_free_index = i;
            break;
        }
        first_free_index = i + 1;
    }
    // Create new utun name based on the first free index.
    let new_utun_name = format!("utun{}", first_free_index);
    if validate_utun_name(&new_utun_name) {
        info!(tun_name=%new_utun_name, "Automatically assigned TUN name.");
        Ok(new_utun_name)
    } else {
        error!("No available TUN name found");
        // `io::Error::other` matches the error construction style used elsewhere in this
        // crate (e.g. `wintun_to_io_error`).
        Err(io::Error::other("No available TUN name"))
    }
}
/// Create a new TUN interface with the given name.
fn create_tun_interface(name: &str) -> Result<tun::AsyncDevice, Box<dyn std::error::Error>> {
    let mut cfg = tun::Configuration::default();
    cfg.name(name)
        .layer(tun::Layer::L3)
        .mtu(LINK_MTU)
        .queues(1)
        .up();
    Ok(tun::create_as_async(&cfg)?)
}
impl IfaceName {
    /// Returns a pointer to the NUL-terminated interface name, suitable for libc calls.
    fn as_ptr(&self) -> *const libc::c_char {
        self.0.as_ptr()
    }
}
impl FromStr for IfaceName {
    type Err = &'static str;

    /// Parses an interface name into the fixed-size, NUL-terminated backing array.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Equal len is not allowed because we need to add the 0 byte terminator.
        if s.len() >= libc::IFNAMSIZ {
            return Err("Interface name too long");
        }
        // A `&str` may legitimately contain interior NUL bytes (e.g. "utun\0"), so this
        // error path is reachable.
        let raw_name = CString::new(s).map_err(|_| "Interface name contains 0 byte")?;
        let mut backing = [0 as libc::c_char; libc::IFNAMSIZ];
        // Copy the name (including the trailing NUL) byte-per-byte. This replaces the
        // original `transmute`, which hard-coded both the array length (16) and the
        // signedness of `libc::c_char`.
        for (dst, &src) in backing.iter_mut().zip(raw_name.to_bytes_with_nul()) {
            *dst = src as libc::c_char;
        }
        Ok(Self(backing))
    }
}
impl Iface {
/// Retrieve the link index of an interface with the given name
fn by_name(name: &str) -> Result<Iface, Box<dyn std::error::Error>> {
let iface_name: IfaceName = name.parse()?;
match unsafe { libc::if_nametoindex(iface_name.as_ptr()) } {
0 => Err(std::io::Error::new(
std::io::ErrorKind::NotFound,
"interface not found",
))?,
_ => Ok(Iface { iface_name }),
}
}
/// Add an address to an interface.
///
/// # Panics
///
/// Only IPv6 is supported, this function will panic when adding an IPv4 subnet.
fn add_address(
&self,
subnet: Subnet,
route_subnet: Subnet,
) -> Result<(), Box<dyn std::error::Error>> {
let addr = if let IpAddr::V6(addr) = subnet.address() {
addr
} else {
panic!("IPv4 subnets are not supported");
};
let mask_addr = if let IpAddr::V6(mask) = route_subnet.mask() {
mask
} else {
// We already know we are IPv6 here
panic!("IPv4 routes are not supported");
};
let sock_addr = SockaddrIn6::from(std::net::SocketAddrV6::new(addr, 0, 0, 0));
let mask = SockaddrIn6::from(std::net::SocketAddrV6::new(mask_addr, 0, 0, 0));
let req = IfaliasReq {
ifname: self.iface_name,
addr: sock_addr,
// SAFETY: kernel expects this to be fully zeroed
dst_addr: unsafe { std::mem::zeroed() },
mask,
flags: IN6_IFF_NODAD | IN6_IFF_SECURED,
lifetime: AddressLifetime {
expire: 0,
preferred: 0,
vltime: ND6_INFINITE_LIFETIME,
pltime: ND6_INFINITE_LIFETIME,
},
};
let sock = random_socket()?;
match unsafe { siocaifaddr_in6(sock.as_raw_fd(), &req) } {
Err(e) => {
error!("Failed to add ipv6 addresst to interface {e}");
Err(std::io::Error::last_os_error())?
}
Ok(_) => {
debug!("Added {subnet} to tun interfacel");
Ok(())
}
}
}
}
// Create a socket to talk to the kernel.
// Any bound socket works here: it is only used as a file descriptor to issue the
// `siocaifaddr_in6` ioctl on, never for actual network traffic.
fn random_socket() -> Result<std::net::UdpSocket, std::io::Error> {
    std::net::UdpSocket::bind("[::1]:0")
}
// Generates `unsafe fn siocaifaddr_in6(fd, *const IfaliasReq)` wrapping the write-pointer
// ioctl with group 'i', number 26 (SIOCAIFADDR_IN6).
nix::ioctl_write_ptr!(
    /// Add an IPv6 subnet to an interface.
    siocaifaddr_in6,
    b'i',
    26,
    IfaliasReq
);

100
mycelium/src/tun/ios.rs Normal file
View File

@@ -0,0 +1,100 @@
//! ios specific tun interface setup.
use std::io::{self, IoSlice};
use futures::{Sink, Stream};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
select,
sync::mpsc,
};
use tracing::{error, info};
use crate::crypto::PacketBuffer;
use crate::tun::TunConfig;
// TODO
const LINK_MTU: i32 = 1400;
/// The 4 byte packet header written before a packet is sent on the TUN
// TODO: figure out structure and values, but for now this seems to work.
const HEADER: [u8; 4] = [0, 0, 0, 30];
/// Create a new tun interface and set required routes
///
/// # Panics
///
/// This function will panic if called outside of the context of a tokio runtime.
pub async fn new(
    tun_config: TunConfig,
) -> Result<
    (
        impl Stream<Item = io::Result<PacketBuffer>>,
        impl Sink<PacketBuffer, Error = impl std::error::Error> + Clone,
    ),
    Box<dyn std::error::Error>,
> {
    let mut tun = create_tun_interface(tun_config.tun_fd)?;
    // Bounded channel for egress (towards the TUN), unbounded channel for ingress.
    let (tun_sink, mut sink_receiver) = mpsc::channel::<PacketBuffer>(1000);
    let (tun_stream, stream_receiver) = mpsc::unbounded_channel();
    // Spawn a single task to manage the TUN interface
    tokio::spawn(async move {
        // Keep an unused buffer around between iterations so its allocation can be reused.
        let mut buf_hold = None;
        loop {
            let mut buf: PacketBuffer = buf_hold.take().unwrap_or_default();
            select! {
                data = sink_receiver.recv() => {
                    match data {
                        None => return,
                        Some(data) => {
                            // We need to append a 4 byte header here
                            if let Err(e) = tun.write_vectored(&[IoSlice::new(&HEADER), IoSlice::new(&data)]).await {
                                error!("Failed to send data to tun interface {e}");
                            }
                        }
                    }
                    // Save the buffer as we didn't use it
                    buf_hold = Some(buf);
                }
                read_result = tun.read(buf.buffer_mut()) => {
                    let rr = read_result.map(|n| {
                        buf.set_size(n);
                        // Trim header
                        // NOTE(review): assumes every successful read returns at least the
                        // 4 byte header; `n - 4` underflows for shorter reads — confirm.
                        buf.buffer_mut().copy_within(4.., 0);
                        buf.set_size(n-4);
                        buf
                    });
                    if tun_stream.send(rr).is_err() {
                        error!("Could not forward data to tun stream, receiver is gone");
                        break;
                    };
                }
            }
        }
        info!("Stop reading from / writing to tun interface");
    });
    Ok((
        tokio_stream::wrappers::UnboundedReceiverStream::new(stream_receiver),
        tokio_util::sync::PollSender::new(tun_sink),
    ))
}
/// Create a new TUN interface backed by the provided, already-open file descriptor.
fn create_tun_interface(tun_fd: i32) -> Result<tun::AsyncDevice, Box<dyn std::error::Error>> {
    let mut cfg = tun::Configuration::default();
    cfg.layer(tun::Layer::L3)
        .mtu(LINK_MTU)
        .queues(1)
        .raw_fd(tun_fd)
        .up();
    Ok(tun::create_as_async(&cfg)?)
}

156
mycelium/src/tun/linux.rs Normal file
View File

@@ -0,0 +1,156 @@
//! Linux specific tun interface setup.
use std::io;
use futures::{Sink, Stream, TryStreamExt};
use rtnetlink::Handle;
use tokio::{select, sync::mpsc};
use tokio_tun::{Tun, TunBuilder};
use tracing::{error, info};
use crate::crypto::PacketBuffer;
use crate::subnet::Subnet;
use crate::tun::TunConfig;
// TODO
const LINK_MTU: i32 = 1400;
/// Create a new tun interface and set required routes
///
/// Returns a stream of packets read from the TUN device and a sink to write packets to it.
///
/// # Panics
///
/// This function will panic if called outside of the context of a tokio runtime.
pub async fn new(
    tun_config: TunConfig,
) -> Result<
    (
        impl Stream<Item = io::Result<PacketBuffer>>,
        impl Sink<PacketBuffer, Error = impl std::error::Error> + Clone,
    ),
    Box<dyn std::error::Error>,
> {
    let tun = match create_tun_interface(&tun_config.name) {
        Ok(tun) => tun,
        Err(e) => {
            error!(
                "Could not create tun device named \"{}\", make sure the name is not yet in use, and you have sufficient privileges to create a network device",
                tun_config.name,
            );
            return Err(e);
        }
    };
    // Short-lived netlink connection, used only to configure the new device.
    let (conn, handle, _) = rtnetlink::new_connection()?;
    let netlink_task_handle = tokio::spawn(conn);
    let tun_index = link_index_by_name(handle.clone(), tun_config.name).await?;
    // Assign the node address with the route subnet's prefix length; the kernel then installs
    // the route for the whole subnet automatically (see `add_address`).
    if let Err(e) = add_address(
        handle.clone(),
        tun_index,
        Subnet::new(
            tun_config.node_subnet.address(),
            tun_config.route_subnet.prefix_len(),
        )
        // `expect` instead of the original bare `unwrap` so a violated invariant is diagnosable.
        .expect("Route subnet prefix length is valid for the node address"),
    )
    .await
    {
        error!(
            "Failed to add address {0} to TUN interface: {e}",
            tun_config.node_subnet
        );
        return Err(e);
    }
    // We are done with our netlink connection, abort the task so we can properly clean up.
    netlink_task_handle.abort();
    // Bounded channel for egress (towards the TUN), unbounded channel for ingress.
    let (tun_sink, mut sink_receiver) = mpsc::channel::<PacketBuffer>(1000);
    let (tun_stream, stream_receiver) = mpsc::unbounded_channel();
    // Spawn a single task to manage the TUN interface
    tokio::spawn(async move {
        // Keep an unused buffer around between iterations so its allocation can be reused.
        let mut buf_hold = None;
        loop {
            let mut buf: PacketBuffer = buf_hold.take().unwrap_or_default();
            select! {
                data = sink_receiver.recv() => {
                    match data {
                        None => return,
                        Some(data) => {
                            if let Err(e) = tun.send(&data).await {
                                error!("Failed to send data to tun interface {e}");
                            }
                        }
                    }
                    // Save the buffer as we didn't use it
                    buf_hold = Some(buf);
                }
                read_result = tun.recv(buf.buffer_mut()) => {
                    let rr = read_result.map(|n| {
                        buf.set_size(n);
                        buf
                    });
                    if tun_stream.send(rr).is_err() {
                        error!("Could not forward data to tun stream, receiver is gone");
                        break;
                    };
                }
            }
        }
        info!("Stop reading from / writing to tun interface");
    });
    Ok((
        tokio_stream::wrappers::UnboundedReceiverStream::new(stream_receiver),
        tokio_util::sync::PollSender::new(tun_sink),
    ))
}
/// Create a new TUN interface with the given name.
fn create_tun_interface(name: &str) -> Result<Tun, Box<dyn std::error::Error>> {
    // `build` returns one device per configured queue; we request exactly one.
    let mut queues = TunBuilder::new()
        .name(name)
        .mtu(LINK_MTU)
        .queues(1)
        .up()
        .build()?;
    let tun = queues
        .pop()
        .expect("Succesfully build tun interface has 1 queue");
    Ok(tun)
}
/// Retrieve the link index of an interface with the given name
async fn link_index_by_name(
    handle: Handle,
    name: String,
) -> Result<u32, Box<dyn std::error::Error>> {
    let maybe_link = handle
        .link()
        .get()
        .match_name(name)
        .execute()
        .try_next()
        .await?;
    match maybe_link {
        Some(link_message) => Ok(link_message.header.index),
        None => Err(io::Error::new(io::ErrorKind::NotFound, "link not found").into()),
    }
}
/// Add an address to an interface.
///
/// The kernel will automatically add a route entry for the subnet assigned to the interface.
async fn add_address(
    handle: Handle,
    link_index: u32,
    subnet: Subnet,
) -> Result<(), Box<dyn std::error::Error>> {
    handle
        .address()
        .add(link_index, subnet.address(), subnet.prefix_len())
        .execute()
        .await
        .map_err(|e| e.into())
}

166
mycelium/src/tun/windows.rs Normal file
View File

@@ -0,0 +1,166 @@
use std::{io, ops::Deref, sync::Arc};
use futures::{Sink, Stream};
use tokio::sync::mpsc;
use tracing::{error, info, warn};
use crate::tun::TunConfig;
use crate::{crypto::PacketBuffer, subnet::Subnet};
// TODO
const LINK_MTU: usize = 1400;
/// Type of the tunnel used, specified when creating the tunnel.
const WINDOWS_TUNNEL_TYPE: &str = "Mycelium";
/// Create a new wintun-backed TUN interface, configure its MTU and address, and return a
/// packet stream / sink pair for it.
pub async fn new(
    tun_config: TunConfig,
) -> Result<
    (
        impl Stream<Item = io::Result<PacketBuffer>>,
        impl Sink<PacketBuffer, Error = impl std::error::Error> + Clone,
    ),
    Box<dyn std::error::Error>,
> {
    // SAFETY: for now we assume a valid wintun.dll file exists in the root directory when we are
    // running this.
    let wintun = unsafe { wintun::load() }?;
    let wintun_version = match wintun::get_running_driver_version(&wintun) {
        Ok(v) => format!("{v}"),
        Err(e) => {
            warn!("Failed to read wintun.dll version: {e}");
            "Unknown".to_string()
        }
    };
    info!("Loaded wintun.dll - running version {wintun_version}");
    let tun = wintun::Adapter::create(&wintun, &tun_config.name, WINDOWS_TUNNEL_TYPE, None)?;
    info!("Created wintun tunnel interface");
    // Configure created network adapter.
    set_adapter_mtu(&tun_config.name, LINK_MTU)?;
    // Set address, this will use a `netsh` command under the hood unfortunately.
    // TODO: fix in library
    // tun.set_network_addresses_tuple(node_subnet.address(), route_subnet.mask(), None)?;
    add_address(
        &tun_config.name,
        tun_config.node_subnet,
        tun_config.route_subnet,
    )?;
    // Build 2 separate sessions - one for receiving, one for sending.
    let rx_session = Arc::new(tun.start_session(wintun::MAX_RING_CAPACITY)?);
    let tx_session = rx_session.clone();
    let (tun_sink, mut sink_receiver) = mpsc::channel::<PacketBuffer>(1000);
    let (tun_stream, stream_receiver) = mpsc::unbounded_channel();
    // Ingress path
    // The wintun API is blocking, so both paths run on dedicated blocking tasks.
    tokio::task::spawn_blocking(move || {
        loop {
            let packet = rx_session
                .receive_blocking()
                .map(|tun_packet| {
                    let mut buffer = PacketBuffer::new();
                    // SAFETY: The configured MTU is smaller than the static PacketBuffer size.
                    let packet_len = tun_packet.bytes().len();
                    buffer.buffer_mut()[..packet_len].copy_from_slice(tun_packet.bytes());
                    buffer.set_size(packet_len);
                    buffer
                })
                .map_err(wintun_to_io_error);
            if tun_stream.send(packet).is_err() {
                error!("Could not forward data to tun stream, receiver is gone");
                break;
            };
        }
        info!("Stop reading from tun interface");
    });
    // Egress path
    tokio::task::spawn_blocking(move || {
        loop {
            match sink_receiver.blocking_recv() {
                None => break,
                Some(data) => {
                    // NOTE(review): `len() as u16` silently truncates packets larger than
                    // 65535 bytes; presumably bounded by the 1400 byte MTU — confirm.
                    let mut tun_packet =
                        match tx_session.allocate_send_packet(data.deref().len() as u16) {
                            Ok(tun_packet) => tun_packet,
                            Err(e) => {
                                error!("Could not allocate packet on TUN: {e}");
                                break;
                            }
                        };
                    // SAFETY: packet allocation is done on the length of &data.
                    tun_packet.bytes_mut().copy_from_slice(&data);
                    tx_session.send_packet(tun_packet);
                }
            }
        }
        info!("Stop writing to tun interface");
    });
    Ok((
        tokio_stream::wrappers::UnboundedReceiverStream::new(stream_receiver),
        tokio_util::sync::PollSender::new(tun_sink),
    ))
}
/// Helper method to convert a [`wintun::Error`] to a [`std::io::Error`].
fn wintun_to_io_error(err: wintun::Error) -> io::Error {
match err {
wintun::Error::Io(e) => e,
_ => io::Error::other("unknown wintun error"),
}
}
/// Set an address on an interface by shelling out to `netsh`
///
/// We assume this is an IPv6 address.
fn add_address(adapter_name: &str, subnet: Subnet, route_subnet: Subnet) -> Result<(), io::Error> {
    run_netsh(&[
        "interface",
        "ipv6",
        "set",
        "address",
        adapter_name,
        &format!("{}/{}", subnet.address(), route_subnet.prefix_len()),
    ])
}
/// Set the MTU of a network adapter by shelling out to `netsh`.
fn set_adapter_mtu(name: &str, mtu: usize) -> Result<(), io::Error> {
    // NOTE(review): the embedded quotes are passed literally to netsh (Command does not use a
    // shell); confirm netsh strips them as intended.
    run_netsh(&[
        "interface",
        "ipv6",
        "set",
        "subinterface",
        &format!("\"{name}\""),
        &format!("mtu={mtu}"),
        "store=persistent",
    ])
}
/// Run `netsh` with the given arguments and interpret its exit status.
///
/// Factored out of `add_address` and `set_adapter_mtu`, which previously duplicated this
/// logic. A non-zero exit code is surfaced as an [`io::Error`]; an unknown exit status is
/// logged and treated as success, matching the original behavior.
fn run_netsh(args: &[&str]) -> Result<(), io::Error> {
    let exit_code = std::process::Command::new("netsh")
        .args(args)
        .spawn()?
        .wait()?;
    match exit_code.code() {
        Some(0) => Ok(()),
        // NOTE(review): netsh exit codes are not OS error codes, so this mapping can yield
        // misleading messages — kept for backward compatibility.
        Some(x) => Err(io::Error::from_raw_os_error(x)),
        None => {
            warn!("Failed to determine `netsh` exit status");
            Ok(())
        }
    }
}