Squashed 'components/mycelium/' content from commit afb32e0
git-subtree-dir: components/mycelium git-subtree-split: afb32e0cdb2d4cdd17f22a5693278068d061f08c
This commit is contained in:
79
mycelium/Cargo.toml
Normal file
79
mycelium/Cargo.toml
Normal file
@@ -0,0 +1,79 @@
|
||||
[package]
|
||||
name = "mycelium"
|
||||
version = "0.6.1"
|
||||
edition = "2021"
|
||||
license-file = "../LICENSE"
|
||||
readme = "../README.md"
|
||||
|
||||
[features]
|
||||
message = []
|
||||
private-network = ["dep:openssl", "dep:tokio-openssl"]
|
||||
vendored-openssl = ["openssl/vendored"]
|
||||
mactunfd = [
|
||||
"tun/appstore",
|
||||
] #mactunfd is a flag to specify that macos should provide tun FD instead of tun name
|
||||
|
||||
[dependencies]
|
||||
cdn-meta = { git = "https://github.com/threefoldtech/mycelium-cdn-registry", package = "cdn-meta" }
|
||||
|
||||
tokio = { version = "1.46.1", features = [
|
||||
"io-util",
|
||||
"fs",
|
||||
"macros",
|
||||
"net",
|
||||
"sync",
|
||||
"time",
|
||||
"rt-multi-thread", # FIXME: remove once tokio::task::block_in_place calls are resolved
|
||||
] }
|
||||
tokio-util = { version = "0.7.15", features = ["codec"] }
|
||||
futures = "0.3.31"
|
||||
serde = { version = "1.0.219", features = ["derive"] }
|
||||
rand = "0.9.1"
|
||||
bytes = "1.10.1"
|
||||
x25519-dalek = { version = "2.0.1", features = ["getrandom", "static_secrets"] }
|
||||
aes-gcm = "0.10.3"
|
||||
tracing = { version = "0.1.41", features = ["release_max_level_debug"] }
|
||||
tracing-subscriber = { version = "0.3.19", features = ["env-filter"] }
|
||||
tracing-logfmt = { version = "0.3.5", features = ["ansi_logs"] }
|
||||
faster-hex = "0.10.0"
|
||||
tokio-stream = { version = "0.1.17", features = ["sync"] }
|
||||
left-right = "0.11.5"
|
||||
ipnet = "2.11.0"
|
||||
ip_network_table-deps-treebitmap = "0.5.0"
|
||||
blake3 = "1.8.2"
|
||||
etherparse = "0.18.0"
|
||||
quinn = { version = "0.11.8", default-features = false, features = [
|
||||
"runtime-tokio",
|
||||
"rustls",
|
||||
] }
|
||||
rustls = { version = "0.23.29", default-features = false, features = ["ring"] }
|
||||
rcgen = "0.14.2"
|
||||
netdev = "0.36.0"
|
||||
openssl = { version = "0.10.73", optional = true }
|
||||
tokio-openssl = { version = "0.6.5", optional = true }
|
||||
arc-swap = "1.7.1"
|
||||
dashmap = { version = "6.1.0", features = ["inline"] }
|
||||
ahash = "0.8.11"
|
||||
axum = "0.8.4"
|
||||
axum-extra = "0.10.1"
|
||||
reqwest = "0.12.22"
|
||||
redis = { version = "0.32.4", features = ["tokio-comp"] }
|
||||
reed-solomon-erasure = "6.0.0"
|
||||
[target.'cfg(target_os = "linux")'.dependencies]
|
||||
rtnetlink = "0.17.0"
|
||||
tokio-tun = "0.13.2"
|
||||
nix = { version = "0.30.1", features = ["socket"] }
|
||||
|
||||
[target.'cfg(target_os = "macos")'.dependencies]
|
||||
tun = { git = "https://github.com/LeeSmet/rust-tun", features = ["async"] }
|
||||
libc = "0.2.174"
|
||||
nix = { version = "0.29.0", features = ["net", "socket", "ioctl"] }
|
||||
|
||||
[target.'cfg(target_os = "windows")'.dependencies]
|
||||
wintun = "0.5.1"
|
||||
|
||||
[target.'cfg(target_os = "android")'.dependencies]
|
||||
tun = { git = "https://github.com/LeeSmet/rust-tun", features = ["async"] }
|
||||
|
||||
[target.'cfg(target_os = "ios")'.dependencies]
|
||||
tun = { git = "https://github.com/LeeSmet/rust-tun", features = ["async"] }
|
||||
321
mycelium/src/babel.rs
Normal file
321
mycelium/src/babel.rs
Normal file
@@ -0,0 +1,321 @@
|
||||
//! This module contains babel related structs.
|
||||
//!
|
||||
//! We don't fully implement the babel spec, and items which are implemented might deviate to fit
|
||||
//! our specific use case. For reference, the implementation is based on [this
|
||||
//! RFC](https://datatracker.ietf.org/doc/html/rfc8966).
|
||||
|
||||
use std::io;
|
||||
|
||||
use bytes::{Buf, BufMut};
|
||||
use tokio_util::codec::{Decoder, Encoder};
|
||||
use tracing::trace;
|
||||
|
||||
pub use self::{
|
||||
hello::Hello, ihu::Ihu, route_request::RouteRequest, seqno_request::SeqNoRequest,
|
||||
update::Update,
|
||||
};
|
||||
|
||||
pub use self::tlv::Tlv;
|
||||
|
||||
mod hello;
|
||||
mod ihu;
|
||||
mod route_request;
|
||||
mod seqno_request;
|
||||
mod tlv;
|
||||
mod update;
|
||||
|
||||
/// Magic byte to identify babel protocol packet.
|
||||
const BABEL_MAGIC: u8 = 42;
|
||||
/// The version of the protocol we are currently using.
|
||||
const BABEL_VERSION: u8 = 3;
|
||||
|
||||
/// Size of a babel header on the wire.
|
||||
const HEADER_WIRE_SIZE: usize = 4;
|
||||
|
||||
/// TLV type for the [`Hello`] tlv
|
||||
const TLV_TYPE_HELLO: u8 = 4;
|
||||
/// TLV type for the [`Ihu`] tlv
|
||||
const TLV_TYPE_IHU: u8 = 5;
|
||||
/// TLV type for the [`Update`] tlv
|
||||
const TLV_TYPE_UPDATE: u8 = 8;
|
||||
/// TLV type for the [`RouteRequest`] tlv
|
||||
const TLV_TYPE_ROUTE_REQUEST: u8 = 9;
|
||||
/// TLV type for the [`SeqNoRequest`] tlv
|
||||
const TLV_TYPE_SEQNO_REQUEST: u8 = 10;
|
||||
|
||||
/// Wildcard address, the value is empty (0 bytes length).
|
||||
const AE_WILDCARD: u8 = 0;
|
||||
/// IPv4 address, the value is _at most_ 4 bytes long.
|
||||
const AE_IPV4: u8 = 1;
|
||||
/// IPv6 address, the value is _at most_ 16 bytes long.
|
||||
const AE_IPV6: u8 = 2;
|
||||
/// Link-local IPv6 address, the value is 8 bytes long. This implies a `fe80::/64` prefix.
|
||||
const AE_IPV6_LL: u8 = 3;
|
||||
|
||||
/// A codec which can send and receive whole babel packets on the wire.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Codec {
|
||||
header: Option<Header>,
|
||||
}
|
||||
|
||||
impl Codec {
|
||||
/// Create a new `BabelCodec`.
|
||||
pub fn new() -> Self {
|
||||
Self { header: None }
|
||||
}
|
||||
|
||||
/// Resets the `BabelCodec` to its default state.
|
||||
pub fn reset(&mut self) {
|
||||
self.header = None;
|
||||
}
|
||||
}
|
||||
|
||||
/// The header for a babel packet. This follows the definition of the header [in the
|
||||
/// RFC](https://datatracker.ietf.org/doc/html/rfc8966#name-packet-format). Since the header
|
||||
/// contains only hard-coded fields and the length of an encoded body, there is no need for users
|
||||
/// to manually construct this. In fact, it exists only to make our lives slightly easier in
|
||||
/// reading/writing the header on the wire.
|
||||
#[derive(Debug, Clone)]
|
||||
struct Header {
|
||||
magic: u8,
|
||||
version: u8,
|
||||
/// This is the length of the whole body following this header. Also excludes any possible
|
||||
/// trailers.
|
||||
body_length: u16,
|
||||
}
|
||||
|
||||
impl Decoder for Codec {
|
||||
type Item = Tlv;
|
||||
|
||||
type Error = io::Error;
|
||||
|
||||
fn decode(&mut self, src: &mut bytes::BytesMut) -> Result<Option<Self::Item>, Self::Error> {
|
||||
// Read a header if we don't have one yet.
|
||||
let header = if let Some(header) = self.header.take() {
|
||||
trace!("Continue from stored header");
|
||||
header
|
||||
} else {
|
||||
if src.remaining() < HEADER_WIRE_SIZE {
|
||||
trace!("Insufficient bytes to read a babel header");
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
trace!("Read babel header");
|
||||
|
||||
Header {
|
||||
magic: src.get_u8(),
|
||||
version: src.get_u8(),
|
||||
body_length: src.get_u16(),
|
||||
}
|
||||
};
|
||||
|
||||
if src.remaining() < header.body_length as usize {
|
||||
trace!("Insufficient bytes to read babel body");
|
||||
self.header = Some(header);
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// Siltently ignore packets which don't have the correct values set, as defined in the
|
||||
// spec. Note that we consume the amount of bytes indentified so we leave the parser in the
|
||||
// correct state for the next packet.
|
||||
if header.magic != BABEL_MAGIC || header.version != BABEL_VERSION {
|
||||
trace!("Dropping babel packet with wrong magic or version");
|
||||
src.advance(header.body_length as usize);
|
||||
self.reset();
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// at this point we have a whole body loaded in the buffer. We currently don't support sub
|
||||
// TLV's
|
||||
|
||||
trace!("Read babel TLV body");
|
||||
|
||||
// TODO: Technically we need to loop here as we can have multiple TLVs.
|
||||
|
||||
// TLV header
|
||||
let tlv_type = src.get_u8();
|
||||
let body_len = src.get_u8();
|
||||
// TLV payload
|
||||
let tlv = match tlv_type {
|
||||
TLV_TYPE_HELLO => Some(Hello::from_bytes(src).into()),
|
||||
TLV_TYPE_IHU => Ihu::from_bytes(src, body_len).map(From::from),
|
||||
TLV_TYPE_UPDATE => Update::from_bytes(src, body_len).map(From::from),
|
||||
TLV_TYPE_ROUTE_REQUEST => RouteRequest::from_bytes(src, body_len).map(From::from),
|
||||
TLV_TYPE_SEQNO_REQUEST => SeqNoRequest::from_bytes(src, body_len).map(From::from),
|
||||
_ => {
|
||||
// unrecoginized body type, silently drop
|
||||
trace!("Dropping unrecognized tlv");
|
||||
// We already read 2 bytes
|
||||
src.advance(header.body_length as usize - 2);
|
||||
self.reset();
|
||||
return Ok(None);
|
||||
}
|
||||
};
|
||||
|
||||
Ok(tlv)
|
||||
}
|
||||
}
|
||||
|
||||
impl Encoder<Tlv> for Codec {
|
||||
type Error = io::Error;
|
||||
|
||||
fn encode(&mut self, item: Tlv, dst: &mut bytes::BytesMut) -> Result<(), Self::Error> {
|
||||
// Write header
|
||||
dst.put_u8(BABEL_MAGIC);
|
||||
dst.put_u8(BABEL_VERSION);
|
||||
dst.put_u16(item.wire_size() as u16 + 2); // tlv payload + tlv header
|
||||
|
||||
// Write TLV's, TODO: currently only 1 TLV/body
|
||||
|
||||
// TLV header
|
||||
match item {
|
||||
Tlv::Hello(_) => dst.put_u8(TLV_TYPE_HELLO),
|
||||
Tlv::Ihu(_) => dst.put_u8(TLV_TYPE_IHU),
|
||||
Tlv::Update(_) => dst.put_u8(TLV_TYPE_UPDATE),
|
||||
Tlv::RouteRequest(_) => dst.put_u8(TLV_TYPE_ROUTE_REQUEST),
|
||||
Tlv::SeqNoRequest(_) => dst.put_u8(TLV_TYPE_SEQNO_REQUEST),
|
||||
}
|
||||
dst.put_u8(item.wire_size());
|
||||
item.write_bytes(dst);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::{net::Ipv6Addr, time::Duration};
|
||||
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use tokio_util::codec::Framed;
|
||||
|
||||
use crate::subnet::Subnet;
|
||||
|
||||
#[tokio::test]
|
||||
async fn codec_hello() {
|
||||
let (tx, rx) = tokio::io::duplex(1024);
|
||||
let mut sender = Framed::new(tx, super::Codec::new());
|
||||
let mut receiver = Framed::new(rx, super::Codec::new());
|
||||
|
||||
let hello = super::Hello::new_unicast(15.into(), 400);
|
||||
|
||||
sender
|
||||
.send(hello.clone().into())
|
||||
.await
|
||||
.expect("Send on a non-networked buffer can never fail; qed");
|
||||
let recv_hello = receiver
|
||||
.next()
|
||||
.await
|
||||
.expect("Buffer isn't closed so this is always `Some`; qed")
|
||||
.expect("Can decode the previously encoded value");
|
||||
assert_eq!(super::Tlv::from(hello), recv_hello);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn codec_ihu() {
|
||||
let (tx, rx) = tokio::io::duplex(1024);
|
||||
let mut sender = Framed::new(tx, super::Codec::new());
|
||||
let mut receiver = Framed::new(rx, super::Codec::new());
|
||||
|
||||
let ihu = super::Ihu::new(27.into(), 400, None);
|
||||
|
||||
sender
|
||||
.send(ihu.clone().into())
|
||||
.await
|
||||
.expect("Send on a non-networked buffer can never fail; qed");
|
||||
let recv_ihu = receiver
|
||||
.next()
|
||||
.await
|
||||
.expect("Buffer isn't closed so this is always `Some`; qed")
|
||||
.expect("Can decode the previously encoded value");
|
||||
assert_eq!(super::Tlv::from(ihu), recv_ihu);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn codec_update() {
|
||||
let (tx, rx) = tokio::io::duplex(1024);
|
||||
let mut sender = Framed::new(tx, super::Codec::new());
|
||||
let mut receiver = Framed::new(rx, super::Codec::new());
|
||||
|
||||
let update = super::Update::new(
|
||||
Duration::from_secs(400),
|
||||
16.into(),
|
||||
25.into(),
|
||||
Subnet::new(Ipv6Addr::new(0x400, 1, 2, 3, 0, 0, 0, 0).into(), 64)
|
||||
.expect("64 is a valid IPv6 prefix size; qed"),
|
||||
[
|
||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
|
||||
24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40,
|
||||
]
|
||||
.into(),
|
||||
);
|
||||
|
||||
sender
|
||||
.send(update.clone().into())
|
||||
.await
|
||||
.expect("Send on a non-networked buffer can never fail; qed");
|
||||
println!("Sent update packet");
|
||||
let recv_update = receiver
|
||||
.next()
|
||||
.await
|
||||
.expect("Buffer isn't closed so this is always `Some`; qed")
|
||||
.expect("Can decode the previously encoded value");
|
||||
println!("Received update packet");
|
||||
assert_eq!(super::Tlv::from(update), recv_update);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn codec_seqno_request() {
|
||||
let (tx, rx) = tokio::io::duplex(1024);
|
||||
let mut sender = Framed::new(tx, super::Codec::new());
|
||||
let mut receiver = Framed::new(rx, super::Codec::new());
|
||||
|
||||
let snr = super::SeqNoRequest::new(
|
||||
16.into(),
|
||||
[
|
||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
|
||||
24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40,
|
||||
]
|
||||
.into(),
|
||||
Subnet::new(Ipv6Addr::new(0x400, 1, 2, 3, 0, 0, 0, 0).into(), 64)
|
||||
.expect("64 is a valid IPv6 prefix size; qed"),
|
||||
);
|
||||
|
||||
sender
|
||||
.send(snr.clone().into())
|
||||
.await
|
||||
.expect("Send on a non-networked buffer can never fail; qed");
|
||||
let recv_update = receiver
|
||||
.next()
|
||||
.await
|
||||
.expect("Buffer isn't closed so this is always `Some`; qed")
|
||||
.expect("Can decode the previously encoded value");
|
||||
assert_eq!(super::Tlv::from(snr), recv_update);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn codec_route_request() {
|
||||
let (tx, rx) = tokio::io::duplex(1024);
|
||||
let mut sender = Framed::new(tx, super::Codec::new());
|
||||
let mut receiver = Framed::new(rx, super::Codec::new());
|
||||
|
||||
let rr = super::RouteRequest::new(
|
||||
Some(
|
||||
Subnet::new(Ipv6Addr::new(0x400, 1, 2, 3, 0, 0, 0, 0).into(), 64)
|
||||
.expect("64 is a valid IPv6 prefix size; qed"),
|
||||
),
|
||||
13,
|
||||
);
|
||||
|
||||
sender
|
||||
.send(rr.clone().into())
|
||||
.await
|
||||
.expect("Send on a non-networked buffer can never fail; qed");
|
||||
let recv_update = receiver
|
||||
.next()
|
||||
.await
|
||||
.expect("Buffer isn't closed so this is always `Some`; qed")
|
||||
.expect("Can decode the previously encoded value");
|
||||
assert_eq!(super::Tlv::from(rr), recv_update);
|
||||
}
|
||||
}
|
||||
162
mycelium/src/babel/hello.rs
Normal file
162
mycelium/src/babel/hello.rs
Normal file
@@ -0,0 +1,162 @@
|
||||
//! The babel [Hello TLV](https://datatracker.ietf.org/doc/html/rfc8966#section-4.6.5).
|
||||
|
||||
use bytes::{Buf, BufMut};
|
||||
use tracing::trace;
|
||||
|
||||
use crate::sequence_number::SeqNo;
|
||||
|
||||
/// Flag bit indicating a [`Hello`] is sent as unicast hello.
|
||||
const HELLO_FLAG_UNICAST: u16 = 0x8000;
|
||||
|
||||
/// Mask to apply to [`Hello`] flags, leaving only valid flags.
|
||||
const FLAG_MASK: u16 = 0b10000000_00000000;
|
||||
|
||||
/// Wire size of a [`Hello`] TLV without TLV header.
|
||||
const HELLO_WIRE_SIZE: u8 = 6;
|
||||
|
||||
/// Hello TLV body as defined in https://datatracker.ietf.org/doc/html/rfc8966#section-4.6.5.
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct Hello {
|
||||
flags: u16,
|
||||
seqno: SeqNo,
|
||||
interval: u16,
|
||||
}
|
||||
|
||||
impl Hello {
|
||||
/// Create a new unicast hello packet.
|
||||
pub fn new_unicast(seqno: SeqNo, interval: u16) -> Self {
|
||||
Self {
|
||||
flags: HELLO_FLAG_UNICAST,
|
||||
seqno,
|
||||
interval,
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculates the size on the wire of this `Hello`.
|
||||
pub fn wire_size(&self) -> u8 {
|
||||
HELLO_WIRE_SIZE
|
||||
}
|
||||
|
||||
/// Construct a `Hello` from wire bytes.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// This function will panic if there are insufficient bytes present in the provided buffer to
|
||||
/// decode a complete `Hello`.
|
||||
pub fn from_bytes(src: &mut bytes::BytesMut) -> Self {
|
||||
let flags = src.get_u16() & FLAG_MASK;
|
||||
let seqno = src.get_u16().into();
|
||||
let interval = src.get_u16();
|
||||
|
||||
trace!("Read hello tlv body");
|
||||
|
||||
Self {
|
||||
flags,
|
||||
seqno,
|
||||
interval,
|
||||
}
|
||||
}
|
||||
|
||||
/// Encode this `Hello` tlv as part of a packet.
|
||||
pub fn write_bytes(&self, dst: &mut bytes::BytesMut) {
|
||||
dst.put_u16(self.flags);
|
||||
dst.put_u16(self.seqno.into());
|
||||
dst.put_u16(self.interval);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use bytes::Buf;
|
||||
|
||||
#[test]
|
||||
fn encoding() {
|
||||
let mut buf = bytes::BytesMut::new();
|
||||
|
||||
let hello = super::Hello {
|
||||
flags: 0,
|
||||
seqno: 25.into(),
|
||||
interval: 400,
|
||||
};
|
||||
|
||||
hello.write_bytes(&mut buf);
|
||||
|
||||
assert_eq!(buf.len(), 6);
|
||||
assert_eq!(buf[..6], [0, 0, 0, 25, 1, 144]);
|
||||
|
||||
let mut buf = bytes::BytesMut::new();
|
||||
|
||||
let hello = super::Hello {
|
||||
flags: super::HELLO_FLAG_UNICAST,
|
||||
seqno: 16.into(),
|
||||
interval: 4000,
|
||||
};
|
||||
|
||||
hello.write_bytes(&mut buf);
|
||||
|
||||
assert_eq!(buf.len(), 6);
|
||||
assert_eq!(buf[..6], [128, 0, 0, 16, 15, 160]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decoding() {
|
||||
let mut buf = bytes::BytesMut::from(&[0b10000000u8, 0b00000000, 0, 19, 2, 1][..]);
|
||||
|
||||
let hello = super::Hello {
|
||||
flags: super::HELLO_FLAG_UNICAST,
|
||||
seqno: 19.into(),
|
||||
interval: 513,
|
||||
};
|
||||
|
||||
assert_eq!(super::Hello::from_bytes(&mut buf), hello);
|
||||
assert_eq!(buf.remaining(), 0);
|
||||
|
||||
let mut buf = bytes::BytesMut::from(&[0b00000000u8, 0b00000000, 1, 19, 200, 100][..]);
|
||||
|
||||
let hello = super::Hello {
|
||||
flags: 0,
|
||||
seqno: 275.into(),
|
||||
interval: 51300,
|
||||
};
|
||||
|
||||
assert_eq!(super::Hello::from_bytes(&mut buf), hello);
|
||||
assert_eq!(buf.remaining(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decode_ignores_invalid_flag_bits() {
|
||||
let mut buf = bytes::BytesMut::from(&[0b10001001u8, 0b00000000, 0, 100, 1, 144][..]);
|
||||
|
||||
let hello = super::Hello {
|
||||
flags: super::HELLO_FLAG_UNICAST,
|
||||
seqno: 100.into(),
|
||||
interval: 400,
|
||||
};
|
||||
|
||||
assert_eq!(super::Hello::from_bytes(&mut buf), hello);
|
||||
assert_eq!(buf.remaining(), 0);
|
||||
|
||||
let mut buf = bytes::BytesMut::from(&[0b00001001u8, 0b00000000, 0, 100, 1, 144][..]);
|
||||
|
||||
let hello = super::Hello {
|
||||
flags: 0,
|
||||
seqno: 100.into(),
|
||||
interval: 400,
|
||||
};
|
||||
|
||||
assert_eq!(super::Hello::from_bytes(&mut buf), hello);
|
||||
assert_eq!(buf.remaining(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn roundtrip() {
|
||||
let mut buf = bytes::BytesMut::new();
|
||||
|
||||
let hello_src = super::Hello::new_unicast(16.into(), 400);
|
||||
hello_src.write_bytes(&mut buf);
|
||||
let decoded = super::Hello::from_bytes(&mut buf);
|
||||
|
||||
assert_eq!(hello_src, decoded);
|
||||
assert_eq!(buf.remaining(), 0);
|
||||
}
|
||||
}
|
||||
246
mycelium/src/babel/ihu.rs
Normal file
246
mycelium/src/babel/ihu.rs
Normal file
@@ -0,0 +1,246 @@
|
||||
//! The babel [IHU TLV](https://datatracker.ietf.org/doc/html/rfc8966#name-ihu).
|
||||
|
||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
|
||||
|
||||
use bytes::{Buf, BufMut};
|
||||
use tracing::trace;
|
||||
|
||||
use crate::metric::Metric;
|
||||
|
||||
use super::{AE_IPV4, AE_IPV6, AE_IPV6_LL, AE_WILDCARD};
|
||||
|
||||
/// Base wire size of an [`Ihu`] without variable length address encoding.
|
||||
const IHU_BASE_WIRE_SIZE: u8 = 6;
|
||||
|
||||
/// IHU TLV body as defined in https://datatracker.ietf.org/doc/html/rfc8966#name-ihu.
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct Ihu {
|
||||
rx_cost: Metric,
|
||||
interval: u16,
|
||||
address: Option<IpAddr>,
|
||||
}
|
||||
|
||||
impl Ihu {
|
||||
/// Create a new `Ihu` to be transmitted.
|
||||
pub fn new(rx_cost: Metric, interval: u16, address: Option<IpAddr>) -> Self {
|
||||
// An interval of 0 is illegal according to the RFC, as this value is used by the receiver
|
||||
// to calculate the hold time.
|
||||
if interval == 0 {
|
||||
panic!("Ihu interval MUST NOT be 0");
|
||||
}
|
||||
Self {
|
||||
rx_cost,
|
||||
interval,
|
||||
address,
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculates the size on the wire of this `Ihu`.
|
||||
pub fn wire_size(&self) -> u8 {
|
||||
IHU_BASE_WIRE_SIZE
|
||||
+ match self.address {
|
||||
None => 0,
|
||||
Some(IpAddr::V4(_)) => 4,
|
||||
// TODO: link local should be encoded differently
|
||||
Some(IpAddr::V6(_)) => 16,
|
||||
}
|
||||
}
|
||||
|
||||
/// Construct a `Ihu` from wire bytes.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// This function will panic if there are insufficient bytes present in the provided buffer to
|
||||
/// decode a complete `Ihu`.
|
||||
pub fn from_bytes(src: &mut bytes::BytesMut, len: u8) -> Option<Self> {
|
||||
let ae = src.get_u8();
|
||||
// read and ignore reserved byte
|
||||
let _ = src.get_u8();
|
||||
let rx_cost = src.get_u16().into();
|
||||
let interval = src.get_u16();
|
||||
let address = match ae {
|
||||
AE_WILDCARD => None,
|
||||
AE_IPV4 => {
|
||||
let mut raw_ip = [0; 4];
|
||||
raw_ip.copy_from_slice(&src[..4]);
|
||||
src.advance(4);
|
||||
Some(Ipv4Addr::from(raw_ip).into())
|
||||
}
|
||||
AE_IPV6 => {
|
||||
let mut raw_ip = [0; 16];
|
||||
raw_ip.copy_from_slice(&src[..16]);
|
||||
src.advance(16);
|
||||
Some(Ipv6Addr::from(raw_ip).into())
|
||||
}
|
||||
AE_IPV6_LL => {
|
||||
let mut raw_ip = [0; 16];
|
||||
raw_ip[0] = 0xfe;
|
||||
raw_ip[1] = 0x80;
|
||||
raw_ip[8..].copy_from_slice(&src[..8]);
|
||||
src.advance(8);
|
||||
Some(Ipv6Addr::from(raw_ip).into())
|
||||
}
|
||||
_ => {
|
||||
// Invalid AE type, skip reamining data and ignore
|
||||
trace!("Invalid AE type in IHU TLV, drop TLV");
|
||||
src.advance(len as usize - 6);
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
trace!("Read ihu tlv body");
|
||||
|
||||
Some(Self {
|
||||
rx_cost,
|
||||
interval,
|
||||
address,
|
||||
})
|
||||
}
|
||||
|
||||
/// Encode this `Ihu` tlv as part of a packet.
|
||||
pub fn write_bytes(&self, dst: &mut bytes::BytesMut) {
|
||||
dst.put_u8(match self.address {
|
||||
None => AE_WILDCARD,
|
||||
Some(IpAddr::V4(_)) => AE_IPV4,
|
||||
Some(IpAddr::V6(_)) => AE_IPV6,
|
||||
});
|
||||
// reserved byte, must be all 0
|
||||
dst.put_u8(0);
|
||||
dst.put_u16(self.rx_cost.into());
|
||||
dst.put_u16(self.interval);
|
||||
match self.address {
|
||||
None => {}
|
||||
Some(IpAddr::V4(ip)) => dst.put_slice(&ip.octets()),
|
||||
Some(IpAddr::V6(ip)) => dst.put_slice(&ip.octets()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::net::{Ipv4Addr, Ipv6Addr};
|
||||
|
||||
use bytes::Buf;
|
||||
|
||||
#[test]
|
||||
fn encoding() {
|
||||
let mut buf = bytes::BytesMut::new();
|
||||
|
||||
let ihu = super::Ihu {
|
||||
rx_cost: 25.into(),
|
||||
interval: 400,
|
||||
address: Some(Ipv4Addr::new(1, 1, 1, 1).into()),
|
||||
};
|
||||
|
||||
ihu.write_bytes(&mut buf);
|
||||
|
||||
assert_eq!(buf.len(), 10);
|
||||
assert_eq!(buf[..10], [1, 0, 0, 25, 1, 144, 1, 1, 1, 1]);
|
||||
|
||||
let mut buf = bytes::BytesMut::new();
|
||||
|
||||
let ihu = super::Ihu {
|
||||
rx_cost: 100.into(),
|
||||
interval: 4000,
|
||||
address: Some(Ipv6Addr::new(2, 0, 1234, 2345, 3456, 4567, 5678, 1).into()),
|
||||
};
|
||||
|
||||
ihu.write_bytes(&mut buf);
|
||||
|
||||
assert_eq!(buf.len(), 22);
|
||||
assert_eq!(
|
||||
buf[..22],
|
||||
[2, 0, 0, 100, 15, 160, 0, 2, 0, 0, 4, 210, 9, 41, 13, 128, 17, 215, 22, 46, 0, 1]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decoding() {
|
||||
let mut buf = bytes::BytesMut::from(&[0, 0, 0, 1, 1, 44][..]);
|
||||
|
||||
let ihu = super::Ihu {
|
||||
rx_cost: 1.into(),
|
||||
interval: 300,
|
||||
address: None,
|
||||
};
|
||||
|
||||
let buf_len = buf.len();
|
||||
assert_eq!(super::Ihu::from_bytes(&mut buf, buf_len as u8), Some(ihu));
|
||||
assert_eq!(buf.remaining(), 0);
|
||||
|
||||
let mut buf = bytes::BytesMut::from(&[1, 0, 0, 2, 0, 44, 3, 4, 5, 6][..]);
|
||||
|
||||
let ihu = super::Ihu {
|
||||
rx_cost: 2.into(),
|
||||
interval: 44,
|
||||
address: Some(Ipv4Addr::new(3, 4, 5, 6).into()),
|
||||
};
|
||||
|
||||
let buf_len = buf.len();
|
||||
assert_eq!(super::Ihu::from_bytes(&mut buf, buf_len as u8), Some(ihu));
|
||||
assert_eq!(buf.remaining(), 0);
|
||||
|
||||
let mut buf = bytes::BytesMut::from(
|
||||
&[
|
||||
2, 0, 0, 2, 0, 44, 4, 0, 0, 0, 0, 5, 0, 6, 7, 8, 9, 10, 11, 12, 13, 14,
|
||||
][..],
|
||||
);
|
||||
|
||||
let ihu = super::Ihu {
|
||||
rx_cost: 2.into(),
|
||||
interval: 44,
|
||||
address: Some(Ipv6Addr::new(0x400, 0, 5, 6, 0x708, 0x90a, 0xb0c, 0xd0e).into()),
|
||||
};
|
||||
|
||||
let buf_len = buf.len();
|
||||
assert_eq!(super::Ihu::from_bytes(&mut buf, buf_len as u8), Some(ihu));
|
||||
assert_eq!(buf.remaining(), 0);
|
||||
|
||||
let mut buf = bytes::BytesMut::from(&[3, 0, 1, 2, 0, 42, 7, 8, 9, 10, 11, 12, 13, 14][..]);
|
||||
|
||||
let ihu = super::Ihu {
|
||||
rx_cost: 258.into(),
|
||||
interval: 42,
|
||||
address: Some(Ipv6Addr::new(0xfe80, 0, 0, 0, 0x708, 0x90a, 0xb0c, 0xd0e).into()),
|
||||
};
|
||||
|
||||
let buf_len = buf.len();
|
||||
assert_eq!(super::Ihu::from_bytes(&mut buf, buf_len as u8), Some(ihu));
|
||||
assert_eq!(buf.remaining(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decode_ignores_invalid_ae_encoding() {
|
||||
// AE 4 as it is the first one which should be used in protocol extension, causing this
|
||||
// test to fail if we forget to update something
|
||||
let mut buf = bytes::BytesMut::from(
|
||||
&[
|
||||
4, 0, 0, 2, 0, 44, 2, 0, 0, 0, 0, 5, 0, 6, 7, 8, 9, 10, 11, 12, 13, 14,
|
||||
][..],
|
||||
);
|
||||
|
||||
let buf_len = buf.len();
|
||||
|
||||
assert_eq!(super::Ihu::from_bytes(&mut buf, buf_len as u8), None);
|
||||
// Decode function should still consume the required amount of bytes to leave parser in a
|
||||
// good state (assuming the length in the tlv preamble is good).
|
||||
assert_eq!(buf.remaining(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn roundtrip() {
|
||||
let mut buf = bytes::BytesMut::new();
|
||||
|
||||
let hello_src = super::Ihu::new(
|
||||
16.into(),
|
||||
400,
|
||||
Some(Ipv6Addr::new(156, 5646, 4164, 1236, 872, 960, 10, 844).into()),
|
||||
);
|
||||
hello_src.write_bytes(&mut buf);
|
||||
let buf_len = buf.len();
|
||||
let decoded = super::Ihu::from_bytes(&mut buf, buf_len as u8);
|
||||
|
||||
assert_eq!(Some(hello_src), decoded);
|
||||
assert_eq!(buf.remaining(), 0);
|
||||
}
|
||||
}
|
||||
301
mycelium/src/babel/route_request.rs
Normal file
301
mycelium/src/babel/route_request.rs
Normal file
@@ -0,0 +1,301 @@
|
||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
|
||||
|
||||
use bytes::{Buf, BufMut};
|
||||
use tracing::trace;
|
||||
|
||||
use crate::subnet::Subnet;
|
||||
|
||||
use super::{AE_IPV4, AE_IPV6, AE_IPV6_LL, AE_WILDCARD};
|
||||
|
||||
/// Base wire size of a [`RouteRequest`] without variable length address encoding.
|
||||
const ROUTE_REQUEST_BASE_WIRE_SIZE: u8 = 3;
|
||||
|
||||
/// Seqno request TLV body as defined in https://datatracker.ietf.org/doc/html/rfc8966#name-route-request
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct RouteRequest {
|
||||
/// The prefix being requested
|
||||
prefix: Option<Subnet>,
|
||||
/// The requests' generation
|
||||
generation: u8,
|
||||
}
|
||||
|
||||
impl RouteRequest {
|
||||
/// Creates a new `RouteRequest` for the given [`prefix`]. If no [`prefix`] is given, a full
|
||||
/// route table dumb in requested.
|
||||
///
|
||||
/// [`prefix`]: Subnet
|
||||
pub fn new(prefix: Option<Subnet>, generation: u8) -> Self {
|
||||
Self { prefix, generation }
|
||||
}
|
||||
|
||||
/// Return the [`prefix`](Subnet) associated with this `RouteRequest`.
|
||||
pub fn prefix(&self) -> Option<Subnet> {
|
||||
self.prefix
|
||||
}
|
||||
|
||||
/// Return the generation of the `RouteRequest`, which is the amount of times it has been
|
||||
/// forwarded already.
|
||||
pub fn generation(&self) -> u8 {
|
||||
self.generation
|
||||
}
|
||||
|
||||
/// Increment the generation of the `RouteRequest`.
|
||||
pub fn inc_generation(&mut self) {
|
||||
self.generation += 1
|
||||
}
|
||||
|
||||
/// Calculates the size on the wire of this `RouteRequest`.
|
||||
pub fn wire_size(&self) -> u8 {
|
||||
ROUTE_REQUEST_BASE_WIRE_SIZE
|
||||
+ (if let Some(prefix) = self.prefix {
|
||||
prefix.prefix_len().div_ceil(8)
|
||||
} else {
|
||||
0
|
||||
})
|
||||
}
|
||||
|
||||
/// Construct a `RouteRequest` from wire bytes.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// This function will panic if there are insufficient bytes present in the provided buffer to
|
||||
/// decode a complete `RouteRequest`.
|
||||
pub fn from_bytes(src: &mut bytes::BytesMut, len: u8) -> Option<Self> {
|
||||
let generation = src.get_u8();
|
||||
let ae = src.get_u8();
|
||||
let plen = src.get_u8();
|
||||
|
||||
let prefix_size = plen.div_ceil(8) as usize;
|
||||
|
||||
let prefix_ip = match ae {
|
||||
AE_WILDCARD => None,
|
||||
AE_IPV4 => {
|
||||
if plen > 32 {
|
||||
return None;
|
||||
}
|
||||
let mut raw_ip = [0; 4];
|
||||
raw_ip[..prefix_size].copy_from_slice(&src[..prefix_size]);
|
||||
src.advance(prefix_size);
|
||||
Some(Ipv4Addr::from(raw_ip).into())
|
||||
}
|
||||
AE_IPV6 => {
|
||||
if plen > 128 {
|
||||
return None;
|
||||
}
|
||||
let mut raw_ip = [0; 16];
|
||||
raw_ip[..prefix_size].copy_from_slice(&src[..prefix_size]);
|
||||
src.advance(prefix_size);
|
||||
Some(Ipv6Addr::from(raw_ip).into())
|
||||
}
|
||||
AE_IPV6_LL => {
|
||||
if plen != 64 {
|
||||
return None;
|
||||
}
|
||||
let mut raw_ip = [0; 16];
|
||||
raw_ip[0] = 0xfe;
|
||||
raw_ip[1] = 0x80;
|
||||
raw_ip[8..].copy_from_slice(&src[..8]);
|
||||
src.advance(8);
|
||||
Some(Ipv6Addr::from(raw_ip).into())
|
||||
}
|
||||
_ => {
|
||||
// Invalid AE type, skip reamining data and ignore
|
||||
trace!("Invalid AE type in route_request packet, drop packet");
|
||||
src.advance(len as usize - 3);
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
let prefix = prefix_ip.and_then(|prefix| Subnet::new(prefix, plen).ok());
|
||||
|
||||
trace!("Read route_request tlv body");
|
||||
|
||||
Some(RouteRequest { prefix, generation })
|
||||
}
|
||||
|
||||
/// Encode this `RouteRequest` tlv as part of a packet.
|
||||
pub fn write_bytes(&self, dst: &mut bytes::BytesMut) {
|
||||
dst.put_u8(self.generation);
|
||||
if let Some(prefix) = self.prefix {
|
||||
dst.put_u8(match prefix.address() {
|
||||
IpAddr::V4(_) => AE_IPV4,
|
||||
IpAddr::V6(_) => AE_IPV6,
|
||||
});
|
||||
dst.put_u8(prefix.prefix_len());
|
||||
let prefix_len = prefix.prefix_len().div_ceil(8) as usize;
|
||||
match prefix.address() {
|
||||
IpAddr::V4(ip) => dst.put_slice(&ip.octets()[..prefix_len]),
|
||||
IpAddr::V6(ip) => dst.put_slice(&ip.octets()[..prefix_len]),
|
||||
}
|
||||
} else {
|
||||
dst.put_u8(AE_WILDCARD);
|
||||
// Prefix len MUST be 0 for wildcard requests
|
||||
dst.put_u8(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::net::{Ipv4Addr, Ipv6Addr};
|
||||
|
||||
use bytes::Buf;
|
||||
|
||||
use crate::subnet::Subnet;
|
||||
|
||||
#[test]
|
||||
fn encoding() {
|
||||
let mut buf = bytes::BytesMut::new();
|
||||
|
||||
let rr = super::RouteRequest {
|
||||
prefix: Some(
|
||||
Subnet::new(Ipv6Addr::new(512, 25, 26, 27, 28, 0, 0, 29).into(), 64)
|
||||
.expect("64 is a valid IPv6 prefix size; qed"),
|
||||
),
|
||||
generation: 2,
|
||||
};
|
||||
|
||||
rr.write_bytes(&mut buf);
|
||||
|
||||
assert_eq!(buf.len(), 11);
|
||||
assert_eq!(buf[..11], [2, 2, 64, 2, 0, 0, 25, 0, 26, 0, 27]);
|
||||
|
||||
let mut buf = bytes::BytesMut::new();
|
||||
|
||||
let rr = super::RouteRequest {
|
||||
prefix: Some(
|
||||
Subnet::new(Ipv4Addr::new(10, 101, 4, 1).into(), 32)
|
||||
.expect("32 is a valid IPv4 prefix size; qed"),
|
||||
),
|
||||
generation: 3,
|
||||
};
|
||||
|
||||
rr.write_bytes(&mut buf);
|
||||
|
||||
assert_eq!(buf.len(), 7);
|
||||
assert_eq!(buf[..7], [3, 1, 32, 10, 101, 4, 1]);
|
||||
|
||||
let mut buf = bytes::BytesMut::new();
|
||||
|
||||
let rr = super::RouteRequest {
|
||||
prefix: None,
|
||||
generation: 0,
|
||||
};
|
||||
|
||||
rr.write_bytes(&mut buf);
|
||||
|
||||
assert_eq!(buf.len(), 3);
|
||||
assert_eq!(buf[..3], [0, 0, 0]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decoding() {
|
||||
let mut buf = bytes::BytesMut::from(&[12, 0, 0][..]);
|
||||
|
||||
let rr = super::RouteRequest {
|
||||
prefix: None,
|
||||
generation: 12,
|
||||
};
|
||||
|
||||
let buf_len = buf.len();
|
||||
assert_eq!(
|
||||
super::RouteRequest::from_bytes(&mut buf, buf_len as u8),
|
||||
Some(rr)
|
||||
);
|
||||
assert_eq!(buf.remaining(), 0);
|
||||
|
||||
let mut buf = bytes::BytesMut::from(&[24, 1, 24, 10, 15, 19][..]);
|
||||
|
||||
let rr = super::RouteRequest {
|
||||
prefix: Some(
|
||||
Subnet::new(Ipv4Addr::new(10, 15, 19, 0).into(), 24)
|
||||
.expect("24 is a valid IPv4 prefix size; qed"),
|
||||
),
|
||||
generation: 24,
|
||||
};
|
||||
|
||||
let buf_len = buf.len();
|
||||
assert_eq!(
|
||||
super::RouteRequest::from_bytes(&mut buf, buf_len as u8),
|
||||
Some(rr)
|
||||
);
|
||||
assert_eq!(buf.remaining(), 0);
|
||||
|
||||
let mut buf = bytes::BytesMut::from(&[7, 2, 64, 0, 10, 0, 20, 0, 30, 0, 40][..]);
|
||||
|
||||
let rr = super::RouteRequest {
|
||||
prefix: Some(
|
||||
Subnet::new(Ipv6Addr::new(10, 20, 30, 40, 0, 0, 0, 0).into(), 64)
|
||||
.expect("64 is a valid IPv6 prefix size; qed"),
|
||||
),
|
||||
generation: 7,
|
||||
};
|
||||
|
||||
let buf_len = buf.len();
|
||||
assert_eq!(
|
||||
super::RouteRequest::from_bytes(&mut buf, buf_len as u8),
|
||||
Some(rr)
|
||||
);
|
||||
assert_eq!(buf.remaining(), 0);
|
||||
|
||||
let mut buf = bytes::BytesMut::from(&[4, 3, 64, 0, 10, 0, 20, 0, 30, 0, 40][..]);
|
||||
|
||||
let rr = super::RouteRequest {
|
||||
prefix: Some(
|
||||
Subnet::new(Ipv6Addr::new(0xfe80, 0, 0, 0, 10, 20, 30, 40).into(), 64)
|
||||
.expect("64 is a valid IPv6 prefix size; qed"),
|
||||
),
|
||||
generation: 4,
|
||||
};
|
||||
|
||||
let buf_len = buf.len();
|
||||
assert_eq!(
|
||||
super::RouteRequest::from_bytes(&mut buf, buf_len as u8),
|
||||
Some(rr)
|
||||
);
|
||||
assert_eq!(buf.remaining(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decode_ignores_invalid_ae_encoding() {
|
||||
// AE 4 as it is the first one which should be used in protocol extension, causing this
|
||||
// test to fail if we forget to update something
|
||||
let mut buf = bytes::BytesMut::from(
|
||||
&[
|
||||
0, 4, 64, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
|
||||
][..],
|
||||
);
|
||||
|
||||
let buf_len = buf.len();
|
||||
|
||||
assert_eq!(
|
||||
super::RouteRequest::from_bytes(&mut buf, buf_len as u8),
|
||||
None
|
||||
);
|
||||
// Decode function should still consume the required amount of bytes to leave parser in a
|
||||
// good state (assuming the length in the tlv preamble is good).
|
||||
assert_eq!(buf.remaining(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn roundtrip() {
|
||||
let mut buf = bytes::BytesMut::new();
|
||||
|
||||
let seqno_src = super::RouteRequest::new(
|
||||
Some(
|
||||
Subnet::new(
|
||||
Ipv6Addr::new(0x21f, 0x4025, 0xabcd, 0xdead, 0, 0, 0, 0).into(),
|
||||
64,
|
||||
)
|
||||
.expect("64 is a valid IPv6 prefix size; qed"),
|
||||
),
|
||||
27,
|
||||
);
|
||||
seqno_src.write_bytes(&mut buf);
|
||||
let buf_len = buf.len();
|
||||
let decoded = super::RouteRequest::from_bytes(&mut buf, buf_len as u8);
|
||||
|
||||
assert_eq!(Some(seqno_src), decoded);
|
||||
assert_eq!(buf.remaining(), 0);
|
||||
}
|
||||
}
|
||||
356
mycelium/src/babel/seqno_request.rs
Normal file
356
mycelium/src/babel/seqno_request.rs
Normal file
@@ -0,0 +1,356 @@
|
||||
use std::{
|
||||
net::{IpAddr, Ipv4Addr, Ipv6Addr},
|
||||
num::NonZeroU8,
|
||||
};
|
||||
|
||||
use bytes::{Buf, BufMut};
|
||||
use tracing::{debug, trace};
|
||||
|
||||
use crate::{router_id::RouterId, sequence_number::SeqNo, subnet::Subnet};
|
||||
|
||||
use super::{AE_IPV4, AE_IPV6, AE_IPV6_LL, AE_WILDCARD};
|
||||
|
||||
/// The default HOP COUNT value used in new SeqNo requests, as per https://datatracker.ietf.org/doc/html/rfc8966#section-3.8.2.1
|
||||
// SAFETY: value is not zero.
|
||||
const DEFAULT_HOP_COUNT: NonZeroU8 = NonZeroU8::new(64).unwrap();
|
||||
|
||||
/// Base wire size of a [`SeqNoRequest`] without variable length address encoding.
|
||||
const SEQNO_REQUEST_BASE_WIRE_SIZE: u8 = 6 + RouterId::BYTE_SIZE as u8;
|
||||
|
||||
/// Seqno request TLV body as defined in https://datatracker.ietf.org/doc/html/rfc8966#name-seqno-request
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct SeqNoRequest {
|
||||
/// The sequence number that is being requested.
|
||||
seqno: SeqNo,
|
||||
/// The maximum number of times this TLV may be forwarded, plus 1.
|
||||
hop_count: NonZeroU8,
|
||||
/// The router id that is being requested.
|
||||
router_id: RouterId,
|
||||
/// The prefix being requested
|
||||
prefix: Subnet,
|
||||
}
|
||||
|
||||
impl SeqNoRequest {
|
||||
/// Create a new `SeqNoRequest` for the given [prefix](Subnet) advertised by the [`RouterId`],
|
||||
/// with the required new [`SeqNo`].
|
||||
pub fn new(seqno: SeqNo, router_id: RouterId, prefix: Subnet) -> SeqNoRequest {
|
||||
Self {
|
||||
seqno,
|
||||
hop_count: DEFAULT_HOP_COUNT,
|
||||
router_id,
|
||||
prefix,
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the [`prefix`](Subnet) associated with this `SeqNoRequest`.
|
||||
pub fn prefix(&self) -> Subnet {
|
||||
self.prefix
|
||||
}
|
||||
|
||||
/// Return the [`RouterId`] associated with this `SeqNoRequest`.
|
||||
pub fn router_id(&self) -> RouterId {
|
||||
self.router_id
|
||||
}
|
||||
|
||||
/// Return the requested [`SeqNo`] associated with this `SeqNoRequest`.
|
||||
pub fn seqno(&self) -> SeqNo {
|
||||
self.seqno
|
||||
}
|
||||
|
||||
/// Get the hop count for this `SeqNoRequest`.
|
||||
pub fn hop_count(&self) -> u8 {
|
||||
self.hop_count.into()
|
||||
}
|
||||
|
||||
/// Decrement the hop count for this `SeqNoRequest`.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// This function will panic if the hop count before calling this function is 1, as that will
|
||||
/// result in a hop count of 0, which is illegal for a `SeqNoRequest`. It is up to the caller
|
||||
/// to ensure this condition holds.
|
||||
pub fn decrement_hop_count(&mut self) {
|
||||
// SAFETY: The panic from this expect is documented in the function signature.
|
||||
self.hop_count = NonZeroU8::new(self.hop_count.get() - 1)
|
||||
.expect("Decrementing a hop count of 1 is not allowed");
|
||||
}
|
||||
|
||||
/// Calculates the size on the wire of this `Update`.
|
||||
pub fn wire_size(&self) -> u8 {
|
||||
SEQNO_REQUEST_BASE_WIRE_SIZE + self.prefix.prefix_len().div_ceil(8)
|
||||
// TODO: Wildcard should be encoded differently
|
||||
}
|
||||
|
||||
/// Construct a `SeqNoRequest` from wire bytes.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// This function will panic if there are insufficient bytes present in the provided buffer to
|
||||
/// decode a complete `SeqNoRequest`.
|
||||
pub fn from_bytes(src: &mut bytes::BytesMut, len: u8) -> Option<Self> {
|
||||
let ae = src.get_u8();
|
||||
let plen = src.get_u8();
|
||||
let seqno = src.get_u16().into();
|
||||
let hop_count = src.get_u8();
|
||||
// Read "reserved" value, we assume this is 0
|
||||
let _ = src.get_u8();
|
||||
|
||||
let mut router_id_bytes = [0u8; RouterId::BYTE_SIZE];
|
||||
router_id_bytes.copy_from_slice(&src[..RouterId::BYTE_SIZE]);
|
||||
src.advance(RouterId::BYTE_SIZE);
|
||||
|
||||
let router_id = RouterId::from(router_id_bytes);
|
||||
|
||||
let prefix_size = plen.div_ceil(8) as usize;
|
||||
|
||||
let prefix = match ae {
|
||||
AE_WILDCARD => {
|
||||
if plen != 0 {
|
||||
return None;
|
||||
}
|
||||
// TODO: this is a temporary placeholder until we figure out how to handle this
|
||||
Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0).into()
|
||||
}
|
||||
AE_IPV4 => {
|
||||
if plen > 32 {
|
||||
return None;
|
||||
}
|
||||
let mut raw_ip = [0; 4];
|
||||
raw_ip[..prefix_size].copy_from_slice(&src[..prefix_size]);
|
||||
src.advance(prefix_size);
|
||||
Ipv4Addr::from(raw_ip).into()
|
||||
}
|
||||
AE_IPV6 => {
|
||||
if plen > 128 {
|
||||
return None;
|
||||
}
|
||||
let mut raw_ip = [0; 16];
|
||||
raw_ip[..prefix_size].copy_from_slice(&src[..prefix_size]);
|
||||
src.advance(prefix_size);
|
||||
Ipv6Addr::from(raw_ip).into()
|
||||
}
|
||||
AE_IPV6_LL => {
|
||||
if plen != 64 {
|
||||
return None;
|
||||
}
|
||||
let mut raw_ip = [0; 16];
|
||||
raw_ip[0] = 0xfe;
|
||||
raw_ip[1] = 0x80;
|
||||
raw_ip[8..].copy_from_slice(&src[..8]);
|
||||
src.advance(8);
|
||||
Ipv6Addr::from(raw_ip).into()
|
||||
}
|
||||
_ => {
|
||||
// Invalid AE type, skip reamining data and ignore
|
||||
trace!("Invalid AE type in seqno_request packet, drop packet");
|
||||
src.advance(len as usize - 46);
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
let prefix = Subnet::new(prefix, plen).ok()?;
|
||||
|
||||
trace!("Read seqno_request tlv body");
|
||||
|
||||
// Make sure hop_count is valid
|
||||
let hop_count = if let Some(hc) = NonZeroU8::new(hop_count) {
|
||||
hc
|
||||
} else {
|
||||
debug!("Dropping seqno_request as hop_count field is set to 0");
|
||||
return None;
|
||||
};
|
||||
|
||||
Some(SeqNoRequest {
|
||||
seqno,
|
||||
hop_count,
|
||||
router_id,
|
||||
prefix,
|
||||
})
|
||||
}
|
||||
|
||||
/// Encode this `SeqNoRequest` tlv as part of a packet.
|
||||
pub fn write_bytes(&self, dst: &mut bytes::BytesMut) {
|
||||
dst.put_u8(match self.prefix.address() {
|
||||
IpAddr::V4(_) => AE_IPV4,
|
||||
IpAddr::V6(_) => AE_IPV6,
|
||||
});
|
||||
dst.put_u8(self.prefix.prefix_len());
|
||||
dst.put_u16(self.seqno.into());
|
||||
dst.put_u8(self.hop_count.into());
|
||||
// Write "reserved" value.
|
||||
dst.put_u8(0);
|
||||
dst.put_slice(&self.router_id.as_bytes()[..]);
|
||||
let prefix_len = self.prefix.prefix_len().div_ceil(8) as usize;
|
||||
match self.prefix.address() {
|
||||
IpAddr::V4(ip) => dst.put_slice(&ip.octets()[..prefix_len]),
|
||||
IpAddr::V6(ip) => dst.put_slice(&ip.octets()[..prefix_len]),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::{
|
||||
net::{Ipv4Addr, Ipv6Addr},
|
||||
num::NonZeroU8,
|
||||
};
|
||||
|
||||
use crate::{router_id::RouterId, subnet::Subnet};
|
||||
use bytes::Buf;
|
||||
|
||||
#[test]
|
||||
fn encoding() {
|
||||
let mut buf = bytes::BytesMut::new();
|
||||
|
||||
let snr = super::SeqNoRequest {
|
||||
seqno: 17.into(),
|
||||
hop_count: NonZeroU8::new(64).unwrap(),
|
||||
prefix: Subnet::new(Ipv6Addr::new(512, 25, 26, 27, 28, 0, 0, 29).into(), 64)
|
||||
.expect("64 is a valid IPv6 prefix size; qed"),
|
||||
router_id: RouterId::from([1u8; RouterId::BYTE_SIZE]),
|
||||
};
|
||||
|
||||
snr.write_bytes(&mut buf);
|
||||
|
||||
assert_eq!(buf.len(), 54);
|
||||
assert_eq!(
|
||||
buf[..54],
|
||||
[
|
||||
2, 64, 0, 17, 64, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 0, 0, 25, 0, 26, 0, 27,
|
||||
]
|
||||
);
|
||||
|
||||
let mut buf = bytes::BytesMut::new();
|
||||
|
||||
let snr = super::SeqNoRequest {
|
||||
seqno: 170.into(),
|
||||
hop_count: NonZeroU8::new(111).unwrap(),
|
||||
prefix: Subnet::new(Ipv4Addr::new(10, 101, 4, 1).into(), 32)
|
||||
.expect("32 is a valid IPv4 prefix size; qed"),
|
||||
router_id: RouterId::from([2u8; RouterId::BYTE_SIZE]),
|
||||
};
|
||||
|
||||
snr.write_bytes(&mut buf);
|
||||
|
||||
assert_eq!(buf.len(), 50);
|
||||
assert_eq!(
|
||||
buf[..50],
|
||||
[
|
||||
1, 32, 0, 170, 111, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
|
||||
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 10, 101, 4, 1,
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decoding() {
|
||||
let mut buf = bytes::BytesMut::from(
|
||||
&[
|
||||
0, 0, 0, 0, 1, 0, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
|
||||
3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
|
||||
][..],
|
||||
);
|
||||
|
||||
let snr = super::SeqNoRequest {
|
||||
hop_count: NonZeroU8::new(1).unwrap(),
|
||||
seqno: 0.into(),
|
||||
prefix: Subnet::new(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0).into(), 0)
|
||||
.expect("0 is a valid IPv6 prefix size; qed"),
|
||||
router_id: RouterId::from([3u8; RouterId::BYTE_SIZE]),
|
||||
};
|
||||
|
||||
let buf_len = buf.len();
|
||||
assert_eq!(
|
||||
super::SeqNoRequest::from_bytes(&mut buf, buf_len as u8),
|
||||
Some(snr)
|
||||
);
|
||||
assert_eq!(buf.remaining(), 0);
|
||||
|
||||
let mut buf = bytes::BytesMut::from(
|
||||
&[
|
||||
3, 64, 0, 42, 232, 0, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
|
||||
4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 10, 0, 20, 0, 30, 0,
|
||||
40,
|
||||
][..],
|
||||
);
|
||||
|
||||
let snr = super::SeqNoRequest {
|
||||
seqno: 42.into(),
|
||||
hop_count: NonZeroU8::new(232).unwrap(),
|
||||
prefix: Subnet::new(Ipv6Addr::new(0xfe80, 0, 0, 0, 10, 20, 30, 40).into(), 64)
|
||||
.expect("92 is a valid IPv6 prefix size; qed"),
|
||||
router_id: RouterId::from([4u8; RouterId::BYTE_SIZE]),
|
||||
};
|
||||
|
||||
let buf_len = buf.len();
|
||||
assert_eq!(
|
||||
super::SeqNoRequest::from_bytes(&mut buf, buf_len as u8),
|
||||
Some(snr)
|
||||
);
|
||||
assert_eq!(buf.remaining(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decode_ignores_invalid_ae_encoding() {
|
||||
// AE 4 as it is the first one which should be used in protocol extension, causing this
|
||||
// test to fail if we forget to update something
|
||||
let mut buf = bytes::BytesMut::from(
|
||||
&[
|
||||
4, 64, 0, 0, 44, 0, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
|
||||
5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 7, 8, 9, 10, 11, 12,
|
||||
13, 14, 15, 16, 17, 18, 19, 20, 21,
|
||||
][..],
|
||||
);
|
||||
|
||||
let buf_len = buf.len();
|
||||
|
||||
assert_eq!(
|
||||
super::SeqNoRequest::from_bytes(&mut buf, buf_len as u8),
|
||||
None
|
||||
);
|
||||
// Decode function should still consume the required amount of bytes to leave parser in a
|
||||
// good state (assuming the length in the tlv preamble is good).
|
||||
assert_eq!(buf.remaining(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decode_ignores_invalid_hop_count() {
|
||||
// Set all flag bits, only allowed bits should be set on the decoded value
|
||||
let mut buf = bytes::BytesMut::from(
|
||||
&[
|
||||
3, 64, 92, 0, 0, 0, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
|
||||
4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 10, 0, 20, 0, 30, 0,
|
||||
40,
|
||||
][..],
|
||||
);
|
||||
|
||||
let buf_len = buf.len();
|
||||
assert_eq!(
|
||||
super::SeqNoRequest::from_bytes(&mut buf, buf_len as u8),
|
||||
None
|
||||
);
|
||||
assert_eq!(buf.remaining(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn roundtrip() {
|
||||
let mut buf = bytes::BytesMut::new();
|
||||
|
||||
let seqno_src = super::SeqNoRequest::new(
|
||||
64.into(),
|
||||
RouterId::from([6; RouterId::BYTE_SIZE]),
|
||||
Subnet::new(
|
||||
Ipv6Addr::new(0x21f, 0x4025, 0xabcd, 0xdead, 0, 0, 0, 0).into(),
|
||||
64,
|
||||
)
|
||||
.expect("64 is a valid IPv6 prefix size; qed"),
|
||||
);
|
||||
seqno_src.write_bytes(&mut buf);
|
||||
let buf_len = buf.len();
|
||||
let decoded = super::SeqNoRequest::from_bytes(&mut buf, buf_len as u8);
|
||||
|
||||
assert_eq!(Some(seqno_src), decoded);
|
||||
assert_eq!(buf.remaining(), 0);
|
||||
}
|
||||
}
|
||||
72
mycelium/src/babel/tlv.rs
Normal file
72
mycelium/src/babel/tlv.rs
Normal file
@@ -0,0 +1,72 @@
|
||||
pub use super::{hello::Hello, ihu::Ihu, update::Update};
|
||||
use super::{route_request::RouteRequest, SeqNoRequest};
|
||||
|
||||
/// A single `Tlv` in a babel packet body.
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum Tlv {
|
||||
/// Hello Tlv type.
|
||||
Hello(Hello),
|
||||
/// Ihu Tlv type.
|
||||
Ihu(Ihu),
|
||||
/// Update Tlv type.
|
||||
Update(Update),
|
||||
/// RouteRequest Tlv type.
|
||||
RouteRequest(RouteRequest),
|
||||
/// SeqNoRequest Tlv type
|
||||
SeqNoRequest(SeqNoRequest),
|
||||
}
|
||||
|
||||
impl Tlv {
|
||||
/// Calculate the size on the wire for this `Tlv`. This DOES NOT included the TLV header size
|
||||
/// (2 bytes).
|
||||
pub fn wire_size(&self) -> u8 {
|
||||
match self {
|
||||
Self::Hello(hello) => hello.wire_size(),
|
||||
Self::Ihu(ihu) => ihu.wire_size(),
|
||||
Self::Update(update) => update.wire_size(),
|
||||
Self::RouteRequest(route_request) => route_request.wire_size(),
|
||||
Self::SeqNoRequest(seqno_request) => seqno_request.wire_size(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Encode this `Tlv` as part of a packet.
|
||||
pub fn write_bytes(&self, dst: &mut bytes::BytesMut) {
|
||||
match self {
|
||||
Self::Hello(hello) => hello.write_bytes(dst),
|
||||
Self::Ihu(ihu) => ihu.write_bytes(dst),
|
||||
Self::Update(update) => update.write_bytes(dst),
|
||||
Self::RouteRequest(route_request) => route_request.write_bytes(dst),
|
||||
Self::SeqNoRequest(seqno_request) => seqno_request.write_bytes(dst),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<SeqNoRequest> for Tlv {
|
||||
fn from(v: SeqNoRequest) -> Self {
|
||||
Self::SeqNoRequest(v)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<RouteRequest> for Tlv {
|
||||
fn from(v: RouteRequest) -> Self {
|
||||
Self::RouteRequest(v)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Update> for Tlv {
|
||||
fn from(v: Update) -> Self {
|
||||
Self::Update(v)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Ihu> for Tlv {
|
||||
fn from(v: Ihu) -> Self {
|
||||
Self::Ihu(v)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Hello> for Tlv {
|
||||
fn from(v: Hello) -> Self {
|
||||
Self::Hello(v)
|
||||
}
|
||||
}
|
||||
385
mycelium/src/babel/update.rs
Normal file
385
mycelium/src/babel/update.rs
Normal file
@@ -0,0 +1,385 @@
|
||||
//! The babel [Update TLV](https://datatracker.ietf.org/doc/html/rfc8966#name-update).
|
||||
|
||||
use std::{
|
||||
net::{IpAddr, Ipv4Addr, Ipv6Addr},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use bytes::{Buf, BufMut};
|
||||
use tracing::trace;
|
||||
|
||||
use crate::{metric::Metric, router_id::RouterId, sequence_number::SeqNo, subnet::Subnet};
|
||||
|
||||
use super::{AE_IPV4, AE_IPV6, AE_IPV6_LL, AE_WILDCARD};
|
||||
|
||||
/// Flag bit indicating an [`Update`] TLV establishes a new default prefix.
|
||||
#[allow(dead_code)]
|
||||
const UPDATE_FLAG_PREFIX: u8 = 0x80;
|
||||
/// Flag bit indicating an [`Update`] TLV establishes a new default router-id.
|
||||
#[allow(dead_code)]
|
||||
const UPDATE_FLAG_ROUTER_ID: u8 = 0x40;
|
||||
/// Mask to apply to [`Update`] flags, leaving only valid flags.
|
||||
const FLAG_MASK: u8 = 0b1100_0000;
|
||||
|
||||
/// Base wire size of an [`Update`] without variable length address encoding.
|
||||
const UPDATE_BASE_WIRE_SIZE: u8 = 10 + RouterId::BYTE_SIZE as u8;
|
||||
|
||||
/// Update TLV body as defined in https://datatracker.ietf.org/doc/html/rfc8966#name-update.
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct Update {
|
||||
/// Flags set in the TLV.
|
||||
flags: u8,
|
||||
/// Upper bound in centiseconds after which a new `Update` is sent. Must not be 0.
|
||||
interval: u16,
|
||||
/// Senders sequence number.
|
||||
seqno: SeqNo,
|
||||
/// Senders metric for this route.
|
||||
metric: Metric,
|
||||
/// The [`Subnet`] contained in this update. An update packet itself can contain any allowed
|
||||
/// subnet.
|
||||
subnet: Subnet,
|
||||
/// Router id of the sender. Importantly this is not part of the update itself, though we do
|
||||
/// transmit it for now as such.
|
||||
router_id: RouterId,
|
||||
}
|
||||
|
||||
impl Update {
|
||||
/// Create a new `Update`.
|
||||
pub fn new(
|
||||
interval: Duration,
|
||||
seqno: SeqNo,
|
||||
metric: Metric,
|
||||
subnet: Subnet,
|
||||
router_id: RouterId,
|
||||
) -> Self {
|
||||
let interval_centiseconds = (interval.as_millis() / 10) as u16;
|
||||
Self {
|
||||
// No flags used for now
|
||||
flags: 0,
|
||||
interval: interval_centiseconds,
|
||||
seqno,
|
||||
metric,
|
||||
subnet,
|
||||
router_id,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the [`SeqNo`] of the sender of this `Update`.
|
||||
pub fn seqno(&self) -> SeqNo {
|
||||
self.seqno
|
||||
}
|
||||
|
||||
/// Return the [`Metric`] of the sender for this route in the `Update`.
|
||||
pub fn metric(&self) -> Metric {
|
||||
self.metric
|
||||
}
|
||||
|
||||
/// Return the [`Subnet`] in this `Update.`
|
||||
pub fn subnet(&self) -> Subnet {
|
||||
self.subnet
|
||||
}
|
||||
|
||||
/// Return the [`router-id`](PublicKey) of the router who advertised this [`Prefix`](IpAddr).
|
||||
pub fn router_id(&self) -> RouterId {
|
||||
self.router_id
|
||||
}
|
||||
|
||||
/// Calculates the size on the wire of this `Update`.
|
||||
pub fn wire_size(&self) -> u8 {
|
||||
let address_bytes = self.subnet.prefix_len().div_ceil(8);
|
||||
UPDATE_BASE_WIRE_SIZE + address_bytes
|
||||
}
|
||||
|
||||
/// Get the time until a new `Update` for the [`Subnet`] is received at the latest.
|
||||
pub fn interval(&self) -> Duration {
|
||||
// Interval is expressed as centiseconds on the wire.
|
||||
Duration::from_millis(self.interval as u64 * 10)
|
||||
}
|
||||
|
||||
/// Construct an `Update` from wire bytes.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// This function will panic if there are insufficient bytes present in the provided buffer to
|
||||
/// decode a complete `Update`.
|
||||
pub fn from_bytes(src: &mut bytes::BytesMut, len: u8) -> Option<Self> {
|
||||
let ae = src.get_u8();
|
||||
let flags = src.get_u8() & FLAG_MASK;
|
||||
let plen = src.get_u8();
|
||||
// Read "omitted" value, we assume this is 0
|
||||
let _ = src.get_u8();
|
||||
let interval = src.get_u16();
|
||||
let seqno = src.get_u16().into();
|
||||
let metric = src.get_u16().into();
|
||||
let prefix_size = plen.div_ceil(8) as usize;
|
||||
let prefix = match ae {
|
||||
AE_WILDCARD => {
|
||||
if prefix_size != 0 {
|
||||
return None;
|
||||
}
|
||||
// TODO: this is a temporary placeholder until we figure out how to handle this
|
||||
Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0).into()
|
||||
}
|
||||
AE_IPV4 => {
|
||||
if plen > 32 {
|
||||
return None;
|
||||
}
|
||||
let mut raw_ip = [0; 4];
|
||||
raw_ip[..prefix_size].copy_from_slice(&src[..prefix_size]);
|
||||
src.advance(prefix_size);
|
||||
Ipv4Addr::from(raw_ip).into()
|
||||
}
|
||||
AE_IPV6 => {
|
||||
if plen > 128 {
|
||||
return None;
|
||||
}
|
||||
let mut raw_ip = [0; 16];
|
||||
raw_ip[..prefix_size].copy_from_slice(&src[..prefix_size]);
|
||||
src.advance(prefix_size);
|
||||
Ipv6Addr::from(raw_ip).into()
|
||||
}
|
||||
AE_IPV6_LL => {
|
||||
if plen != 64 {
|
||||
return None;
|
||||
}
|
||||
let mut raw_ip = [0; 16];
|
||||
raw_ip[0] = 0xfe;
|
||||
raw_ip[1] = 0x80;
|
||||
raw_ip[8..].copy_from_slice(&src[..8]);
|
||||
src.advance(8);
|
||||
Ipv6Addr::from(raw_ip).into()
|
||||
}
|
||||
_ => {
|
||||
// Invalid AE type, skip reamining data and ignore
|
||||
trace!("Invalid AE type in update packet, drop packet");
|
||||
src.advance(len as usize - 10);
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
let subnet = Subnet::new(prefix, plen).ok()?;
|
||||
|
||||
let mut router_id_bytes = [0u8; RouterId::BYTE_SIZE];
|
||||
router_id_bytes.copy_from_slice(&src[..RouterId::BYTE_SIZE]);
|
||||
src.advance(RouterId::BYTE_SIZE);
|
||||
|
||||
let router_id = RouterId::from(router_id_bytes);
|
||||
|
||||
trace!("Read update tlv body");
|
||||
|
||||
Some(Update {
|
||||
flags,
|
||||
interval,
|
||||
seqno,
|
||||
metric,
|
||||
subnet,
|
||||
router_id,
|
||||
})
|
||||
}
|
||||
|
||||
/// Encode this `Update` tlv as part of a packet.
|
||||
pub fn write_bytes(&self, dst: &mut bytes::BytesMut) {
|
||||
dst.put_u8(match self.subnet.address() {
|
||||
IpAddr::V4(_) => AE_IPV4,
|
||||
IpAddr::V6(_) => AE_IPV6,
|
||||
});
|
||||
dst.put_u8(self.flags);
|
||||
dst.put_u8(self.subnet.prefix_len());
|
||||
// Write "omitted" value, currently not used in our encoding scheme.
|
||||
dst.put_u8(0);
|
||||
dst.put_u16(self.interval);
|
||||
dst.put_u16(self.seqno.into());
|
||||
dst.put_u16(self.metric.into());
|
||||
let prefix_len = self.subnet.prefix_len().div_ceil(8) as usize;
|
||||
match self.subnet.address() {
|
||||
IpAddr::V4(ip) => dst.put_slice(&ip.octets()[..prefix_len]),
|
||||
IpAddr::V6(ip) => dst.put_slice(&ip.octets()[..prefix_len]),
|
||||
}
|
||||
dst.put_slice(&self.router_id.as_bytes()[..])
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::{
|
||||
net::{Ipv4Addr, Ipv6Addr},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use crate::{router_id::RouterId, subnet::Subnet};
|
||||
use bytes::Buf;
|
||||
|
||||
#[test]
|
||||
fn encoding() {
|
||||
let mut buf = bytes::BytesMut::new();
|
||||
|
||||
let ihu = super::Update {
|
||||
flags: 0b1100_0000,
|
||||
interval: 400,
|
||||
seqno: 17.into(),
|
||||
metric: 25.into(),
|
||||
subnet: Subnet::new(Ipv6Addr::new(512, 25, 26, 27, 28, 0, 0, 29).into(), 64)
|
||||
.expect("64 is a valid IPv6 prefix size; qed"),
|
||||
router_id: RouterId::from([1u8; RouterId::BYTE_SIZE]),
|
||||
};
|
||||
|
||||
ihu.write_bytes(&mut buf);
|
||||
|
||||
assert_eq!(buf.len(), 58);
|
||||
assert_eq!(
|
||||
buf[..58],
|
||||
[
|
||||
2, 192, 64, 0, 1, 144, 0, 17, 0, 25, 2, 0, 0, 25, 0, 26, 0, 27, 1, 1, 1, 1, 1, 1,
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||
1, 1, 1, 1, 1, 1
|
||||
]
|
||||
);
|
||||
|
||||
let mut buf = bytes::BytesMut::new();
|
||||
|
||||
let ihu = super::Update {
|
||||
flags: 0b0000_0000,
|
||||
interval: 600,
|
||||
seqno: 170.into(),
|
||||
metric: 256.into(),
|
||||
subnet: Subnet::new(Ipv4Addr::new(10, 101, 4, 1).into(), 23)
|
||||
.expect("23 is a valid IPv4 prefix size; qed"),
|
||||
router_id: RouterId::from([2u8; RouterId::BYTE_SIZE]),
|
||||
};
|
||||
|
||||
ihu.write_bytes(&mut buf);
|
||||
|
||||
assert_eq!(buf.len(), 53);
|
||||
assert_eq!(
|
||||
buf[..53],
|
||||
[
|
||||
1, 0, 23, 0, 2, 88, 0, 170, 1, 0, 10, 101, 4, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
|
||||
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decoding() {
|
||||
let mut buf = bytes::BytesMut::from(
|
||||
&[
|
||||
0, 64, 0, 0, 0, 100, 0, 70, 2, 0, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
|
||||
3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
|
||||
][..],
|
||||
);
|
||||
|
||||
let ihu = super::Update {
|
||||
flags: 0b0100_0000,
|
||||
interval: 100,
|
||||
seqno: 70.into(),
|
||||
metric: 512.into(),
|
||||
subnet: Subnet::new(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0).into(), 0)
|
||||
.expect("0 is a valid IPv6 prefix size; qed"),
|
||||
router_id: RouterId::from([3u8; RouterId::BYTE_SIZE]),
|
||||
};
|
||||
|
||||
let buf_len = buf.len();
|
||||
assert_eq!(
|
||||
super::Update::from_bytes(&mut buf, buf_len as u8),
|
||||
Some(ihu)
|
||||
);
|
||||
assert_eq!(buf.remaining(), 0);
|
||||
|
||||
let mut buf = bytes::BytesMut::from(
|
||||
&[
|
||||
3, 0, 64, 0, 3, 232, 0, 42, 3, 1, 0, 10, 0, 20, 0, 30, 0, 40, 4, 4, 4, 4, 4, 4, 4,
|
||||
4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
|
||||
4, 4, 4, 4, 4,
|
||||
][..],
|
||||
);
|
||||
|
||||
let ihu = super::Update {
|
||||
flags: 0b0000_0000,
|
||||
interval: 1000,
|
||||
seqno: 42.into(),
|
||||
metric: 769.into(),
|
||||
subnet: Subnet::new(Ipv6Addr::new(0xfe80, 0, 0, 0, 10, 20, 30, 40).into(), 64)
|
||||
.expect("92 is a valid IPv6 prefix size; qed"),
|
||||
router_id: RouterId::from([4u8; RouterId::BYTE_SIZE]),
|
||||
};
|
||||
|
||||
let buf_len = buf.len();
|
||||
assert_eq!(
|
||||
super::Update::from_bytes(&mut buf, buf_len as u8),
|
||||
Some(ihu)
|
||||
);
|
||||
assert_eq!(buf.remaining(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decode_ignores_invalid_ae_encoding() {
|
||||
// AE 4 as it is the first one which should be used in protocol extension, causing this
|
||||
// test to fail if we forget to update something
|
||||
let mut buf = bytes::BytesMut::from(
|
||||
&[
|
||||
4, 0, 64, 0, 0, 44, 2, 0, 0, 10, 10, 5, 0, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
|
||||
17, 18, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
|
||||
5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
|
||||
][..],
|
||||
);
|
||||
|
||||
let buf_len = buf.len();
|
||||
|
||||
assert_eq!(super::Update::from_bytes(&mut buf, buf_len as u8), None);
|
||||
// Decode function should still consume the required amount of bytes to leave parser in a
|
||||
// good state (assuming the length in the tlv preamble is good).
|
||||
assert_eq!(buf.remaining(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decode_ignores_invalid_flag_bits() {
|
||||
// Set all flag bits, only allowed bits should be set on the decoded value
|
||||
let mut buf = bytes::BytesMut::from(
|
||||
&[
|
||||
3, 255, 64, 0, 3, 232, 0, 42, 3, 1, 0, 10, 0, 20, 0, 30, 0, 40, 4, 4, 4, 4, 4, 4,
|
||||
4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
|
||||
4, 4, 4, 4, 4, 4,
|
||||
][..],
|
||||
);
|
||||
|
||||
let ihu = super::Update {
|
||||
flags: super::UPDATE_FLAG_PREFIX | super::UPDATE_FLAG_ROUTER_ID,
|
||||
interval: 1000,
|
||||
seqno: 42.into(),
|
||||
metric: 769.into(),
|
||||
subnet: Subnet::new(Ipv6Addr::new(0xfe80, 0, 0, 0, 10, 20, 30, 40).into(), 64)
|
||||
.expect("92 is a valid IPv6 prefix size; qed"),
|
||||
router_id: RouterId::from([4u8; RouterId::BYTE_SIZE]),
|
||||
};
|
||||
|
||||
let buf_len = buf.len();
|
||||
assert_eq!(
|
||||
super::Update::from_bytes(&mut buf, buf_len as u8),
|
||||
Some(ihu)
|
||||
);
|
||||
assert_eq!(buf.remaining(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn roundtrip() {
|
||||
let mut buf = bytes::BytesMut::new();
|
||||
|
||||
let hello_src = super::Update::new(
|
||||
Duration::from_secs(64),
|
||||
10.into(),
|
||||
25.into(),
|
||||
Subnet::new(
|
||||
Ipv6Addr::new(0x21f, 0x4025, 0xabcd, 0xdead, 0, 0, 0, 0).into(),
|
||||
64,
|
||||
)
|
||||
.expect("64 is a valid IPv6 prefix size; qed"),
|
||||
RouterId::from([6; RouterId::BYTE_SIZE]),
|
||||
);
|
||||
hello_src.write_bytes(&mut buf);
|
||||
let buf_len = buf.len();
|
||||
let decoded = super::Update::from_bytes(&mut buf, buf_len as u8);
|
||||
|
||||
assert_eq!(Some(hello_src), decoded);
|
||||
assert_eq!(buf.remaining(), 0);
|
||||
}
|
||||
}
|
||||
338
mycelium/src/cdn.rs
Normal file
338
mycelium/src/cdn.rs
Normal file
@@ -0,0 +1,338 @@
|
||||
use std::path::PathBuf;
|
||||
|
||||
use aes_gcm::{aead::Aead, KeyInit};
|
||||
use axum::{
|
||||
extract::{Query, State},
|
||||
http::{HeaderMap, StatusCode},
|
||||
routing::get,
|
||||
Router,
|
||||
};
|
||||
use axum_extra::extract::Host;
|
||||
use futures::{stream::FuturesUnordered, StreamExt};
|
||||
use reqwest::header::CONTENT_TYPE;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
/// Cdn functionality. Urls of specific format lead to donwnlaoding of metadata from the registry,
|
||||
/// and serving of chunks.
|
||||
pub struct Cdn {
|
||||
cache: PathBuf,
|
||||
cancel_token: CancellationToken,
|
||||
}
|
||||
|
||||
/// Cache for reconstructed blocks
|
||||
#[derive(Clone)]
|
||||
struct Cache {
|
||||
base: PathBuf,
|
||||
}
|
||||
|
||||
impl Cdn {
|
||||
pub fn new(cache: PathBuf) -> Self {
|
||||
let cancel_token = CancellationToken::new();
|
||||
Self {
|
||||
cache,
|
||||
cancel_token,
|
||||
}
|
||||
}
|
||||
|
||||
/// Start the Cdn server. This future runs until the server is stopped.
|
||||
pub fn start(&self, listener: TcpListener) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let state = Cache {
|
||||
base: self.cache.clone(),
|
||||
};
|
||||
|
||||
if !self.cache.exists() {
|
||||
info!(dir = %self.cache.display(), "Creating cache dir");
|
||||
std::fs::create_dir(&self.cache)?;
|
||||
}
|
||||
|
||||
if !self.cache.is_dir() {
|
||||
return Err("Cache dir is not a directory".into());
|
||||
}
|
||||
|
||||
let router = Router::new().route("/", get(cdn)).with_state(state);
|
||||
|
||||
let cancel_token = self.cancel_token.clone();
|
||||
|
||||
tokio::spawn(async {
|
||||
axum::serve(listener, router)
|
||||
.with_graceful_shutdown(cancel_token.cancelled_owned())
|
||||
.await
|
||||
.map_err(|err| {
|
||||
warn!(%err, "Cdn server error");
|
||||
})
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Deserialize)]
|
||||
struct DecryptionKeyQuery {
|
||||
key: Option<String>,
|
||||
}
|
||||
|
||||
#[tracing::instrument(level = tracing::Level::DEBUG, skip(cache))]
|
||||
async fn cdn(
|
||||
Host(host): Host,
|
||||
Query(query): Query<DecryptionKeyQuery>,
|
||||
State(cache): State<Cache>,
|
||||
) -> Result<(HeaderMap, Vec<u8>), StatusCode> {
|
||||
debug!("Received request at {host}");
|
||||
let mut parts = host.split('.');
|
||||
let prefix = parts
|
||||
.next()
|
||||
.expect("Splitting a String always yields at least 1 result; Qed.");
|
||||
if prefix.len() != 32 {
|
||||
return Err(StatusCode::BAD_REQUEST);
|
||||
}
|
||||
|
||||
let mut hash = [0; 16];
|
||||
faster_hex::hex_decode(prefix.as_bytes(), &mut hash).map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
|
||||
let registry_url = parts.collect::<Vec<_>>().join(".");
|
||||
|
||||
let decryption_key = if let Some(query_key) = query.key {
|
||||
let mut key = [0; 16];
|
||||
faster_hex::hex_decode(query_key.as_bytes(), &mut key)
|
||||
.map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
Some(key)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let meta = load_meta(registry_url.clone(), hash, decryption_key).await?;
|
||||
debug!("Metadata loaded");
|
||||
|
||||
let mut headers = HeaderMap::new();
|
||||
match meta {
|
||||
cdn_meta::Metadata::File(file) => {
|
||||
//
|
||||
if let Some(mime) = file.mime {
|
||||
debug!(%mime, "Setting mime type");
|
||||
headers.append(
|
||||
CONTENT_TYPE,
|
||||
mime.parse().map_err(|_| {
|
||||
warn!("Not serving file with unprocessable mime type");
|
||||
StatusCode::UNPROCESSABLE_ENTITY
|
||||
})?,
|
||||
);
|
||||
}
|
||||
|
||||
// File recombination
|
||||
let mut content = vec![];
|
||||
for block in file.blocks {
|
||||
content.extend_from_slice(cache.fetch_block(&block).await?.as_slice());
|
||||
}
|
||||
Ok((headers, content))
|
||||
}
|
||||
cdn_meta::Metadata::Directory(dir) => {
|
||||
let mut out = r#"
|
||||
<!DOCTYPE html>
|
||||
|
||||
<html i18n-values="dir:textdirection;lang:language">
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
</head>
|
||||
<body>
|
||||
<ul>"#
|
||||
.to_string();
|
||||
headers.append(
|
||||
CONTENT_TYPE,
|
||||
"text/html"
|
||||
.parse()
|
||||
.expect("Can parse \"text/html\" to content-type"),
|
||||
);
|
||||
for (file_hash, encryption_key) in dir.files {
|
||||
let meta = load_meta(registry_url.clone(), file_hash, encryption_key).await?;
|
||||
let name = match meta {
|
||||
cdn_meta::Metadata::File(file) => file.name,
|
||||
cdn_meta::Metadata::Directory(dir) => dir.name,
|
||||
};
|
||||
out.push_str(&format!(
|
||||
"<li><a href=\"http://{}.{registry_url}/?key={}\">{name}</a></li>\n",
|
||||
faster_hex::hex_string(&file_hash),
|
||||
&encryption_key
|
||||
.map(|ek| faster_hex::hex_string(&ek))
|
||||
.unwrap_or_else(String::new),
|
||||
));
|
||||
}
|
||||
|
||||
out.push_str("</ul></body></html>");
|
||||
Ok((headers, out.into()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Load a metadata blob from a metadata repository.
|
||||
async fn load_meta(
|
||||
registry_url: String,
|
||||
hash: cdn_meta::Hash,
|
||||
encryption_key: Option<cdn_meta::Hash>,
|
||||
) -> Result<cdn_meta::Metadata, StatusCode> {
|
||||
let mut r_url = reqwest::Url::parse(&format!("http://{registry_url}")).map_err(|err| {
|
||||
error!(%err, "Could not parse registry URL");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
let hex_hash = faster_hex::hex_string(&hash);
|
||||
r_url.set_path(&format!("/api/v1/metadata/{hex_hash}"));
|
||||
r_url.set_scheme("http").map_err(|_| {
|
||||
error!("Could not set HTTP scheme");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
debug!(url = %r_url, "Fetching chunk");
|
||||
|
||||
let metadata_reply = reqwest::get(r_url).await.map_err(|err| {
|
||||
error!(%err, "Could not load metadata from registry");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
// TODO: Should we just check if status code is success here?
|
||||
if metadata_reply.status() != StatusCode::OK {
|
||||
debug!(
|
||||
status = %metadata_reply.status(),
|
||||
"Registry replied with non-OK status code"
|
||||
);
|
||||
return Err(metadata_reply.status());
|
||||
}
|
||||
|
||||
let encrypted_metadata = metadata_reply.bytes().await.map_err(|err| {
|
||||
error!(%err, "Could not load metadata response from registry");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
let metadata = if let Some(encryption_key) = encryption_key {
|
||||
if encrypted_metadata.len() < 12 {
|
||||
debug!("Attempting to decrypt metadata with inufficient size");
|
||||
return Err(StatusCode::UNPROCESSABLE_ENTITY);
|
||||
}
|
||||
|
||||
let decryptor = aes_gcm::Aes128Gcm::new(&encryption_key.into());
|
||||
let plaintext = decryptor
|
||||
.decrypt(
|
||||
encrypted_metadata[encrypted_metadata.len() - 12..].into(),
|
||||
&encrypted_metadata[..encrypted_metadata.len() - 12],
|
||||
)
|
||||
.map_err(|_| {
|
||||
warn!("Decryption of block failed");
|
||||
// Either the decryption key is wrong or the blob is corrupt, we assume the
|
||||
// registry is not a fault so the decryption key is wrong, which is a user error.
|
||||
StatusCode::UNPROCESSABLE_ENTITY
|
||||
})?;
|
||||
|
||||
plaintext
|
||||
} else {
|
||||
encrypted_metadata.into()
|
||||
};
|
||||
|
||||
// If the metadata is not decodable, this is not really our fault, but also not the necessarily
|
||||
// the users fault.
|
||||
let (meta, consumed) =
|
||||
cdn_meta::Metadata::from_binary(&metadata).map_err(|_| StatusCode::UNPROCESSABLE_ENTITY)?;
|
||||
if consumed != metadata.len() {
|
||||
warn!(
|
||||
metadata_length = metadata.len(),
|
||||
consumed, "Trailing binary metadata which wasn't decoded"
|
||||
);
|
||||
}
|
||||
|
||||
Ok(meta)
|
||||
}
|
||||
|
||||
impl Drop for Cdn {
|
||||
fn drop(&mut self) {
|
||||
self.cancel_token.cancel();
|
||||
}
|
||||
}
|
||||
|
||||
/// Download a shard from a 0-db.
|
||||
async fn download_shard(
|
||||
location: &cdn_meta::Location,
|
||||
key: &[u8],
|
||||
) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
|
||||
let client = redis::Client::open(format!("redis://{}", location.host))?;
|
||||
let mut con = client.get_multiplexed_async_connection().await?;
|
||||
|
||||
redis::cmd("SELECT")
|
||||
.arg(&location.namespace)
|
||||
.query_async::<()>(&mut con)
|
||||
.await?;
|
||||
|
||||
Ok(redis::cmd("GET").arg(key).query_async(&mut con).await?)
|
||||
}
|
||||
|
||||
impl Cache {
|
||||
async fn fetch_block(&self, block: &cdn_meta::Block) -> Result<Vec<u8>, StatusCode> {
|
||||
let mut cached_file_path = self.base.clone();
|
||||
cached_file_path.push(faster_hex::hex_string(&block.encrypted_hash));
|
||||
// If we have the file in cache, just open it, load it, and return from there.
|
||||
if cached_file_path.exists() {
|
||||
return tokio::fs::read(&cached_file_path).await.map_err(|err| {
|
||||
error!(%err, "Could not load cached file");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
});
|
||||
}
|
||||
|
||||
// File is not in cache, download and save
|
||||
|
||||
// TODO: Rank based on expected latency
|
||||
// FIXME: Only download the required amount
|
||||
let mut shard_stream = block
|
||||
.shards
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, loc)| async move { (i, download_shard(loc, &block.encrypted_hash).await) })
|
||||
.collect::<FuturesUnordered<_>>();
|
||||
let mut shards = vec![None; block.shards.len()];
|
||||
while let Some((idx, shard)) = shard_stream.next().await {
|
||||
let shard = shard.map_err(|err| {
|
||||
warn!(err, "Could not load shard");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
shards[idx] = Some(shard);
|
||||
}
|
||||
// recombine
|
||||
let encoder = reed_solomon_erasure::galois_8::ReedSolomon::new(
|
||||
block.required_shards as usize,
|
||||
block.shards.len() - block.required_shards as usize,
|
||||
)
|
||||
.map_err(|err| {
|
||||
error!(%err, "Failed to construct erausre codec");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
encoder.reconstruct_data(&mut shards).map_err(|err| {
|
||||
error!(%err, "Shard recombination failed");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
// SAFETY: Since decoding was succesfull, the first shards (data shards) must be
|
||||
// Option::Some
|
||||
let mut encrypted_data = shards
|
||||
.into_iter()
|
||||
.map(Option::unwrap)
|
||||
.take(block.required_shards as usize)
|
||||
.flatten()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let padding_len = encrypted_data[encrypted_data.len() - 1] as usize;
|
||||
encrypted_data.resize(encrypted_data.len() - padding_len, 0);
|
||||
|
||||
let decryptor = aes_gcm::Aes128Gcm::new(&block.content_hash.into());
|
||||
let c = decryptor
|
||||
.decrypt(&block.nonce.into(), encrypted_data.as_slice())
|
||||
.map_err(|err| {
|
||||
warn!(%err, "Decryption of content block failed");
|
||||
StatusCode::UNPROCESSABLE_ENTITY
|
||||
})?;
|
||||
|
||||
// Save file to cache, this is not critical if it fails
|
||||
if let Err(err) = tokio::fs::write(&cached_file_path, &c).await {
|
||||
warn!(%err, "Could not write block to cache");
|
||||
};
|
||||
|
||||
Ok(c)
|
||||
}
|
||||
}
|
||||
158
mycelium/src/connection.rs
Normal file
158
mycelium/src/connection.rs
Normal file
@@ -0,0 +1,158 @@
|
||||
use std::{io, net::SocketAddr, pin::Pin};
|
||||
|
||||
use tokio::{
|
||||
io::{AsyncRead, AsyncWrite},
|
||||
net::TcpStream,
|
||||
};
|
||||
|
||||
mod tracked;
|
||||
pub use tracked::Tracked;
|
||||
|
||||
#[cfg(feature = "private-network")]
|
||||
mod tls;
|
||||
|
||||
/// Cost to add to the peer_link_cost for "local processing", when peers are connected over IPv6.
|
||||
///
|
||||
/// The current peer link cost is calculated from a HELLO rtt. This is great to measure link
|
||||
/// latency, since packets are processed in order. However, on local idle links, this value will
|
||||
/// likely be 0 since we round down (from the amount of ms it took to process), which does not
|
||||
/// accurately reflect the fact that there is in fact a cost associated with using a peer, even on
|
||||
/// these local links.
|
||||
const PACKET_PROCESSING_COST_IP6_TCP: u16 = 10;
|
||||
|
||||
/// Cost to add to the peer_link_cost for "local processing", when peers are connected over IPv6.
|
||||
///
|
||||
/// This is similar to [`PACKET_PROCESSING_COST_IP6`], but slightly higher so we skew towards IPv6
|
||||
/// connections if peers are connected over both IPv4 and IPv6.
|
||||
const PACKET_PROCESSING_COST_IP4_TCP: u16 = 15;
|
||||
|
||||
// TODO
|
||||
const PACKET_PROCESSING_COST_IP6_QUIC: u16 = 7;
|
||||
// TODO
|
||||
const PACKET_PROCESSING_COST_IP4_QUIC: u16 = 12;
|
||||
|
||||
pub trait Connection: AsyncRead + AsyncWrite {
|
||||
/// Get an identifier for this connection, which shows details about the remote
|
||||
fn identifier(&self) -> Result<String, io::Error>;
|
||||
|
||||
/// The static cost of using this connection
|
||||
fn static_link_cost(&self) -> Result<u16, io::Error>;
|
||||
}
|
||||
|
||||
/// A wrapper around a quic send and quic receive stream, implementing the [`Connection`] trait.
|
||||
pub struct Quic {
|
||||
tx: quinn::SendStream,
|
||||
rx: quinn::RecvStream,
|
||||
remote: SocketAddr,
|
||||
}
|
||||
|
||||
impl Quic {
|
||||
/// Create a new wrapper around Quic streams.
|
||||
pub fn new(tx: quinn::SendStream, rx: quinn::RecvStream, remote: SocketAddr) -> Self {
|
||||
Quic { tx, rx, remote }
|
||||
}
|
||||
}
|
||||
|
||||
impl Connection for TcpStream {
|
||||
fn identifier(&self) -> Result<String, io::Error> {
|
||||
Ok(format!(
|
||||
"TCP {} <-> {}",
|
||||
self.local_addr()?,
|
||||
self.peer_addr()?
|
||||
))
|
||||
}
|
||||
|
||||
fn static_link_cost(&self) -> Result<u16, io::Error> {
|
||||
Ok(match self.peer_addr()? {
|
||||
SocketAddr::V4(_) => PACKET_PROCESSING_COST_IP4_TCP,
|
||||
SocketAddr::V6(ip) if ip.ip().to_ipv4_mapped().is_some() => {
|
||||
PACKET_PROCESSING_COST_IP4_TCP
|
||||
}
|
||||
SocketAddr::V6(_) => PACKET_PROCESSING_COST_IP6_TCP,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for Quic {
|
||||
#[inline]
|
||||
fn poll_read(
|
||||
mut self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &mut tokio::io::ReadBuf<'_>,
|
||||
) -> std::task::Poll<io::Result<()>> {
|
||||
Pin::new(&mut self.rx).poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for Quic {
|
||||
#[inline]
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> std::task::Poll<Result<usize, io::Error>> {
|
||||
Pin::new(&mut self.tx)
|
||||
.poll_write(cx, buf)
|
||||
.map_err(From::from)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn poll_flush(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<(), io::Error>> {
|
||||
Pin::new(&mut self.tx).poll_flush(cx)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn poll_shutdown(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<(), io::Error>> {
|
||||
Pin::new(&mut self.tx).poll_shutdown(cx)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn poll_write_vectored(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
bufs: &[io::IoSlice<'_>],
|
||||
) -> std::task::Poll<Result<usize, io::Error>> {
|
||||
Pin::new(&mut self.tx).poll_write_vectored(cx, bufs)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn is_write_vectored(&self) -> bool {
|
||||
self.tx.is_write_vectored()
|
||||
}
|
||||
}
|
||||
|
||||
impl Connection for Quic {
|
||||
fn identifier(&self) -> Result<String, io::Error> {
|
||||
Ok(format!("QUIC -> {}", self.remote))
|
||||
}
|
||||
|
||||
fn static_link_cost(&self) -> Result<u16, io::Error> {
|
||||
Ok(match self.remote {
|
||||
SocketAddr::V4(_) => PACKET_PROCESSING_COST_IP4_QUIC,
|
||||
SocketAddr::V6(ip) if ip.ip().to_ipv4_mapped().is_some() => {
|
||||
PACKET_PROCESSING_COST_IP4_QUIC
|
||||
}
|
||||
SocketAddr::V6(_) => PACKET_PROCESSING_COST_IP6_QUIC,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
use tokio::io::DuplexStream;
|
||||
|
||||
#[cfg(test)]
|
||||
impl Connection for DuplexStream {
|
||||
fn identifier(&self) -> Result<String, io::Error> {
|
||||
Ok("Memory pipe".to_string())
|
||||
}
|
||||
|
||||
fn static_link_cost(&self) -> Result<u16, io::Error> {
|
||||
Ok(1)
|
||||
}
|
||||
}
|
||||
23
mycelium/src/connection/tls.rs
Normal file
23
mycelium/src/connection/tls.rs
Normal file
@@ -0,0 +1,23 @@
|
||||
use std::{io, net::SocketAddr};
|
||||
|
||||
use tokio::net::TcpStream;
|
||||
|
||||
impl super::Connection for tokio_openssl::SslStream<TcpStream> {
|
||||
fn identifier(&self) -> Result<String, io::Error> {
|
||||
Ok(format!(
|
||||
"TLS {} <-> {}",
|
||||
self.get_ref().local_addr()?,
|
||||
self.get_ref().peer_addr()?
|
||||
))
|
||||
}
|
||||
|
||||
fn static_link_cost(&self) -> Result<u16, io::Error> {
|
||||
Ok(match self.get_ref().peer_addr()? {
|
||||
SocketAddr::V4(_) => super::PACKET_PROCESSING_COST_IP4_TCP,
|
||||
SocketAddr::V6(ip) if ip.ip().to_ipv4_mapped().is_some() => {
|
||||
super::PACKET_PROCESSING_COST_IP4_TCP
|
||||
}
|
||||
SocketAddr::V6(_) => super::PACKET_PROCESSING_COST_IP6_TCP,
|
||||
})
|
||||
}
|
||||
}
|
||||
120
mycelium/src/connection/tracked.rs
Normal file
120
mycelium/src/connection/tracked.rs
Normal file
@@ -0,0 +1,120 @@
|
||||
use std::{
|
||||
pin::Pin,
|
||||
sync::{
|
||||
atomic::{AtomicU64, Ordering},
|
||||
Arc,
|
||||
},
|
||||
task::Poll,
|
||||
};
|
||||
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
|
||||
use super::Connection;
|
||||
|
||||
/// Wrapper which keeps track of how much bytes have been read and written from a connection.
|
||||
pub struct Tracked<C> {
|
||||
/// Bytes read counter
|
||||
read: Arc<AtomicU64>,
|
||||
/// Bytes written counter
|
||||
write: Arc<AtomicU64>,
|
||||
/// Underlying connection we are measuring
|
||||
con: C,
|
||||
}
|
||||
|
||||
impl<C> Tracked<C>
|
||||
where
|
||||
C: Connection + Unpin,
|
||||
{
|
||||
/// Create a new instance of a tracked connections. Counters are passed in so they can be
|
||||
/// reused accross connections.
|
||||
pub fn new(read: Arc<AtomicU64>, write: Arc<AtomicU64>, con: C) -> Self {
|
||||
Self { read, write, con }
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> Connection for Tracked<C>
|
||||
where
|
||||
C: Connection + Unpin,
|
||||
{
|
||||
#[inline]
|
||||
fn identifier(&self) -> Result<String, std::io::Error> {
|
||||
self.con.identifier()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn static_link_cost(&self) -> Result<u16, std::io::Error> {
|
||||
self.con.static_link_cost()
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> AsyncRead for Tracked<C>
|
||||
where
|
||||
C: AsyncRead + Unpin,
|
||||
{
|
||||
#[inline]
|
||||
fn poll_read(
|
||||
mut self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &mut tokio::io::ReadBuf<'_>,
|
||||
) -> std::task::Poll<std::io::Result<()>> {
|
||||
let start_len = buf.filled().len();
|
||||
let res = Pin::new(&mut self.con).poll_read(cx, buf);
|
||||
if let Poll::Ready(Ok(())) = res {
|
||||
self.read
|
||||
.fetch_add((buf.filled().len() - start_len) as u64, Ordering::Relaxed);
|
||||
}
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> AsyncWrite for Tracked<C>
|
||||
where
|
||||
C: AsyncWrite + Unpin,
|
||||
{
|
||||
#[inline]
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize, std::io::Error>> {
|
||||
let res = Pin::new(&mut self.con).poll_write(cx, buf);
|
||||
if let Poll::Ready(Ok(written)) = res {
|
||||
self.write.fetch_add(written as u64, Ordering::Relaxed);
|
||||
}
|
||||
res
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn poll_flush(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Result<(), std::io::Error>> {
|
||||
Pin::new(&mut self.con).poll_flush(cx)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn poll_shutdown(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Result<(), std::io::Error>> {
|
||||
Pin::new(&mut self.con).poll_shutdown(cx)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn poll_write_vectored(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
bufs: &[std::io::IoSlice<'_>],
|
||||
) -> Poll<Result<usize, std::io::Error>> {
|
||||
let res = Pin::new(&mut self.con).poll_write_vectored(cx, bufs);
|
||||
if let Poll::Ready(Ok(written)) = res {
|
||||
self.write.fetch_add(written as u64, Ordering::Relaxed);
|
||||
}
|
||||
res
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn is_write_vectored(&self) -> bool {
|
||||
self.con.is_write_vectored()
|
||||
}
|
||||
}
|
||||
450
mycelium/src/crypto.rs
Normal file
450
mycelium/src/crypto.rs
Normal file
@@ -0,0 +1,450 @@
|
||||
//! Abstraction over diffie hellman, symmetric encryption, and hashing.
|
||||
|
||||
use core::fmt;
|
||||
use std::{
|
||||
error::Error,
|
||||
fmt::Display,
|
||||
net::Ipv6Addr,
|
||||
ops::{Deref, DerefMut},
|
||||
};
|
||||
|
||||
use aes_gcm::{aead::OsRng, AeadCore, AeadInPlace, Aes256Gcm, Key, KeyInit};
|
||||
use serde::{de::Visitor, Deserialize, Serialize};
|
||||
|
||||
/// Default MTU for a packet. Ideally this would not be needed and the [`PacketBuffer`] takes a
|
||||
/// const generic argument which is then expanded with the needed extra space for the buffer,
|
||||
/// however as it stands const generics can only be used standalone and not in a constant
|
||||
/// expression. This _is_ possible on nightly rust, with a feature gate (generic_const_exprs).
|
||||
const PACKET_SIZE: usize = 1400;
|
||||
|
||||
/// Size of an AES_GCM tag in bytes.
|
||||
const AES_TAG_SIZE: usize = 16;
|
||||
|
||||
/// Size of an AES_GCM nonce in bytes.
|
||||
const AES_NONCE_SIZE: usize = 12;
|
||||
|
||||
/// Size of user defined data header. This header will be part of the encrypted data.
|
||||
const DATA_HEADER_SIZE: usize = 4;
|
||||
|
||||
/// Size of a `PacketBuffer`.
|
||||
const PACKET_BUFFER_SIZE: usize = PACKET_SIZE + AES_TAG_SIZE + AES_NONCE_SIZE + DATA_HEADER_SIZE;
|
||||
|
||||
/// A public key used as part of Diffie Hellman key exchange. It is derived from a [`SecretKey`].
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub struct PublicKey(x25519_dalek::PublicKey);
|
||||
|
||||
/// A secret used as part of Diffie Hellman key exchange.
|
||||
///
|
||||
/// This type intentionally does not implement or derive [`Debug`] to avoid accidentally leaking
|
||||
/// secrets in logs.
|
||||
#[derive(Clone)]
|
||||
pub struct SecretKey(x25519_dalek::StaticSecret);
|
||||
|
||||
/// A statically computed secret from a [`SecretKey`] and a [`PublicKey`].
|
||||
///
|
||||
/// This type intentionally does not implement or derive [`Debug`] to avoid accidentally leaking
|
||||
/// secrets in logs.
|
||||
#[derive(Clone)]
|
||||
pub struct SharedSecret([u8; 32]);
|
||||
|
||||
/// A buffer for packets. This holds enough space to encrypt a packet in place without
|
||||
/// reallocating.
|
||||
///
|
||||
/// Internally, the buffer is created with an additional header. Because this header is part of the
|
||||
/// encrypted content, it is not included in the global version set by the main packet header. As
|
||||
/// such, an internal version is included.
|
||||
pub struct PacketBuffer {
|
||||
buf: Vec<u8>,
|
||||
/// Amount of bytes written in the buffer
|
||||
size: usize,
|
||||
}
|
||||
|
||||
/// A reference to the header in a [`PacketBuffer`].
|
||||
pub struct PacketBufferHeader<'a> {
|
||||
data: &'a [u8; DATA_HEADER_SIZE],
|
||||
}
|
||||
|
||||
/// A mutable reference to the header in a [`PacketBuffer`].
|
||||
pub struct PacketBufferHeaderMut<'a> {
|
||||
data: &'a mut [u8; DATA_HEADER_SIZE],
|
||||
}
|
||||
|
||||
/// Opaque type indicating decryption failed.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct DecryptionError;
|
||||
|
||||
impl Display for DecryptionError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.write_str("Decryption failed, invalid or insufficient encrypted content for this key")
|
||||
}
|
||||
}
|
||||
|
||||
impl Error for DecryptionError {}
|
||||
|
||||
impl SecretKey {
|
||||
/// Generate a new `StaticSecret` using [`OsRng`] as an entropy source.
|
||||
pub fn new() -> Self {
|
||||
SecretKey(x25519_dalek::StaticSecret::random_from_rng(OsRng))
|
||||
}
|
||||
|
||||
/// View this `SecretKey` as a byte array.
|
||||
#[inline]
|
||||
pub fn as_bytes(&self) -> &[u8; 32] {
|
||||
self.0.as_bytes()
|
||||
}
|
||||
|
||||
/// Computes the [`SharedSecret`] from this `SecretKey` and a [`PublicKey`].
|
||||
pub fn shared_secret(&self, other: &PublicKey) -> SharedSecret {
|
||||
SharedSecret(self.0.diffie_hellman(&other.0).to_bytes())
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SecretKey {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl PublicKey {
|
||||
/// Generates an [`Ipv6Addr`] from a `PublicKey`.
|
||||
///
|
||||
/// The generated address is guaranteed to be part of the `400::/7` range.
|
||||
pub fn address(&self) -> Ipv6Addr {
|
||||
let mut hasher = blake3::Hasher::new();
|
||||
hasher.update(self.as_bytes());
|
||||
let mut buf = [0; 16];
|
||||
hasher.finalize_xof().fill(&mut buf);
|
||||
// Mangle the first byte to be of the expected form. Because of the network range
|
||||
// requirement, we MUST set the third bit, and MAY set the last bit. Instead of discarding
|
||||
// the first 7 bits of the hash, use the first byte to determine if the last bit is set.
|
||||
// If there is an odd number of bits set in the first byte, set the last bit of the result.
|
||||
let lsb = buf[0].count_ones() as u8 % 2;
|
||||
buf[0] = 0x04 | lsb;
|
||||
Ipv6Addr::from(buf)
|
||||
}
|
||||
|
||||
/// Convert this `PublicKey` to a byte array.
|
||||
pub fn to_bytes(self) -> [u8; 32] {
|
||||
self.0.to_bytes()
|
||||
}
|
||||
|
||||
/// View this `PublicKey` as a byte array.
|
||||
pub fn as_bytes(&self) -> &[u8; 32] {
|
||||
self.0.as_bytes()
|
||||
}
|
||||
}
|
||||
|
||||
impl SharedSecret {
|
||||
/// Encrypt a [`PacketBuffer`] using the `SharedSecret` as key.
|
||||
///
|
||||
/// Internally, a new random nonce will be generated using the OS's crypto rng generator. This
|
||||
/// nonce is appended to the encrypted data.
|
||||
pub fn encrypt(&self, mut data: PacketBuffer) -> Vec<u8> {
|
||||
let key: Key<Aes256Gcm> = self.0.into();
|
||||
let nonce = Aes256Gcm::generate_nonce(OsRng);
|
||||
|
||||
let cipher = Aes256Gcm::new(&key);
|
||||
let tag = cipher
|
||||
.encrypt_in_place_detached(&nonce, &[], &mut data.buf[..data.size])
|
||||
.expect("Encryption can't fail; qed.");
|
||||
|
||||
data.buf[data.size..data.size + AES_TAG_SIZE].clone_from_slice(tag.as_slice());
|
||||
data.buf[data.size + AES_TAG_SIZE..data.size + AES_TAG_SIZE + AES_NONCE_SIZE]
|
||||
.clone_from_slice(&nonce);
|
||||
|
||||
data.buf.truncate(data.size + AES_NONCE_SIZE + AES_TAG_SIZE);
|
||||
|
||||
data.buf
|
||||
}
|
||||
|
||||
/// Decrypt a message previously encrypted with an equivalent `SharedSecret`. In other words, a
|
||||
/// message that was previously created by the [`SharedSecret::encrypt`] method.
|
||||
///
|
||||
/// Internally, this messages assumes that a 12 byte nonce is present at the end of the data.
|
||||
/// If the passed in data to decrypt does not contain a valid nonce, decryption fails and an
|
||||
/// opaque error is returned. As an extension to this, if the data is not of sufficient length
|
||||
/// to contain a valid nonce, an error is returned immediately.
|
||||
pub fn decrypt(&self, mut data: Vec<u8>) -> Result<PacketBuffer, DecryptionError> {
|
||||
// Make sure we have sufficient data (i.e. a nonce).
|
||||
if data.len() < AES_NONCE_SIZE + AES_TAG_SIZE + DATA_HEADER_SIZE {
|
||||
return Err(DecryptionError);
|
||||
}
|
||||
|
||||
let data_len = data.len();
|
||||
|
||||
let key: Key<Aes256Gcm> = self.0.into();
|
||||
{
|
||||
let (data, nonce) = data.split_at_mut(data_len - AES_NONCE_SIZE);
|
||||
let (data, tag) = data.split_at_mut(data.len() - AES_TAG_SIZE);
|
||||
|
||||
let cipher = Aes256Gcm::new(&key);
|
||||
cipher
|
||||
.decrypt_in_place_detached((&*nonce).into(), &[], data, (&*tag).into())
|
||||
.map_err(|_| DecryptionError)?;
|
||||
}
|
||||
|
||||
Ok(PacketBuffer {
|
||||
// We did not remove the scratch space used for TAG and NONCE.
|
||||
size: data.len() - AES_TAG_SIZE - AES_NONCE_SIZE,
|
||||
buf: data,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl PacketBuffer {
|
||||
/// Create a new blank `PacketBuffer`.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
buf: vec![0; PACKET_BUFFER_SIZE],
|
||||
size: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a reference to the packet header.
|
||||
pub fn header(&self) -> PacketBufferHeader<'_> {
|
||||
PacketBufferHeader {
|
||||
data: self.buf[..DATA_HEADER_SIZE]
|
||||
.try_into()
|
||||
.expect("Header size constant is correct; qed"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a mutable reference to the packet header.
|
||||
pub fn header_mut(&mut self) -> PacketBufferHeaderMut<'_> {
|
||||
PacketBufferHeaderMut {
|
||||
data: <&mut [u8] as TryInto<&mut [u8; DATA_HEADER_SIZE]>>::try_into(
|
||||
&mut self.buf[..DATA_HEADER_SIZE],
|
||||
)
|
||||
.expect("Header size constant is correct; qed"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a reference to the entire useable inner buffer.
|
||||
pub fn buffer(&self) -> &[u8] {
|
||||
let buf_end = self.buf.len() - AES_NONCE_SIZE - AES_TAG_SIZE;
|
||||
&self.buf[DATA_HEADER_SIZE..buf_end]
|
||||
}
|
||||
|
||||
/// Get a mutable reference to the entire useable internal buffer.
|
||||
pub fn buffer_mut(&mut self) -> &mut [u8] {
|
||||
let buf_end = self.buf.len() - AES_NONCE_SIZE - AES_TAG_SIZE;
|
||||
&mut self.buf[DATA_HEADER_SIZE..buf_end]
|
||||
}
|
||||
|
||||
/// Sets the amount of bytes in use by the buffer.
|
||||
pub fn set_size(&mut self, size: usize) {
|
||||
self.size = size + DATA_HEADER_SIZE;
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for PacketBuffer {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<[u8; 32]> for SecretKey {
|
||||
/// Load a secret key from a byte array.
|
||||
fn from(bytes: [u8; 32]) -> SecretKey {
|
||||
SecretKey(x25519_dalek::StaticSecret::from(bytes))
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for PublicKey {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.write_str(&faster_hex::hex_string(self.as_bytes()))
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for PublicKey {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
serializer.serialize_str(&faster_hex::hex_string(self.as_bytes()))
|
||||
}
|
||||
}
|
||||
|
||||
struct PublicKeyVisitor;
|
||||
impl Visitor<'_> for PublicKeyVisitor {
|
||||
type Value = PublicKey;
|
||||
|
||||
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
formatter.write_str("A hex encoded public key (64 characters)")
|
||||
}
|
||||
|
||||
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
|
||||
where
|
||||
E: serde::de::Error,
|
||||
{
|
||||
if v.len() != 64 {
|
||||
Err(E::custom("Public key is 64 characters long"))
|
||||
} else {
|
||||
let mut backing = [0; 32];
|
||||
faster_hex::hex_decode(v.as_bytes(), &mut backing)
|
||||
.map_err(|_| E::custom("PublicKey is not valid hex"))?;
|
||||
Ok(PublicKey(backing.into()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for PublicKey {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
deserializer.deserialize_str(PublicKeyVisitor)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<[u8; 32]> for PublicKey {
|
||||
/// Given a byte array, construct a `PublicKey`.
|
||||
fn from(bytes: [u8; 32]) -> PublicKey {
|
||||
PublicKey(x25519_dalek::PublicKey::from(bytes))
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&str> for PublicKey {
|
||||
type Error = faster_hex::Error;
|
||||
|
||||
fn try_from(value: &str) -> Result<Self, Self::Error> {
|
||||
let mut output = [0u8; 32];
|
||||
faster_hex::hex_decode(value.as_bytes(), &mut output)?;
|
||||
Ok(PublicKey::from(output))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&SecretKey> for PublicKey {
|
||||
fn from(value: &SecretKey) -> Self {
|
||||
PublicKey(x25519_dalek::PublicKey::from(&value.0))
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for SharedSecret {
|
||||
type Target = [u8; 32];
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for PacketBuffer {
|
||||
type Target = [u8];
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.buf[DATA_HEADER_SIZE..self.size]
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for PacketBufferHeader<'_> {
|
||||
type Target = [u8; DATA_HEADER_SIZE];
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for PacketBufferHeaderMut<'_> {
|
||||
type Target = [u8; DATA_HEADER_SIZE];
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl DerefMut for PacketBufferHeaderMut<'_> {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for PacketBuffer {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("PacketBuffer")
|
||||
.field("data", &"...")
|
||||
.field("len", &self.size)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{PacketBuffer, SecretKey, AES_NONCE_SIZE, AES_TAG_SIZE, DATA_HEADER_SIZE};
|
||||
|
||||
#[test]
|
||||
/// Test if encryption works in general. We just create some random value and encrypt it.
|
||||
/// Specifically, this will help to catch runtime panics in case AES_TAG_SIZE or AES_NONCE_SIZE
|
||||
/// don't have a proper value aligned with the underlying AES_GCM implementation.
|
||||
fn encryption_succeeds() {
|
||||
let k1 = SecretKey::new();
|
||||
let k2 = SecretKey::new();
|
||||
|
||||
let ss = k1.shared_secret(&(&k2).into());
|
||||
|
||||
let mut pb = PacketBuffer::new();
|
||||
let data = b"vnno30nv f654q364 vfsv 44"; // Random keyboard smash.
|
||||
|
||||
pb.buffer_mut()[..data.len()].copy_from_slice(data);
|
||||
pb.set_size(data.len());
|
||||
|
||||
// We only care that this does not panic.
|
||||
let res = ss.encrypt(pb);
|
||||
// At the same time, check expected size.
|
||||
assert_eq!(
|
||||
res.len(),
|
||||
data.len() + DATA_HEADER_SIZE + AES_TAG_SIZE + AES_NONCE_SIZE
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
/// Encrypt a value and then decrypt it. This makes sure the decrypt flow and encrypt flow
|
||||
/// match, and both follow the expected format. Also, we don't reuse the shared secret for
|
||||
/// decryption, but instead generate the secret again the other way round, to simulate a remote
|
||||
/// node.
|
||||
fn encrypt_decrypt_roundtrip() {
|
||||
let k1 = SecretKey::new();
|
||||
let k2 = SecretKey::new();
|
||||
|
||||
let ss1 = k1.shared_secret(&(&k2).into());
|
||||
let ss2 = k2.shared_secret(&(&k1).into());
|
||||
|
||||
// This assertion is not strictly necessary as it will be checked below implicitly.
|
||||
assert_eq!(ss1.as_slice(), ss2.as_slice());
|
||||
|
||||
let data = b"dsafjiqjo23 u2953u8 3oid fjo321j";
|
||||
let mut pb = PacketBuffer::new();
|
||||
|
||||
pb.buffer_mut()[..data.len()].copy_from_slice(data);
|
||||
pb.set_size(data.len());
|
||||
|
||||
let res = ss1.encrypt(pb);
|
||||
|
||||
let original = ss2.decrypt(res).expect("Decryption works");
|
||||
|
||||
assert_eq!(&*original, &data[..]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
/// Test if PacketBufferHeaderMut actually modifies the PacketBuffer storage.
|
||||
fn modify_header() {
|
||||
let mut pb = PacketBuffer::new();
|
||||
let mut header = pb.header_mut();
|
||||
|
||||
header[0] = 1;
|
||||
header[1] = 2;
|
||||
header[2] = 3;
|
||||
header[3] = 4;
|
||||
|
||||
assert_eq!(pb.buf[..DATA_HEADER_SIZE], [1, 2, 3, 4]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
/// Verify [`PacketBuffer::buffer`] and [`PacketBuffer::buffer_mut`] actually have the
|
||||
/// appropriate size.
|
||||
fn buffer_mapping() {
|
||||
let mut pb = PacketBuffer::new();
|
||||
|
||||
assert_eq!(pb.buffer().len(), super::PACKET_SIZE);
|
||||
assert_eq!(pb.buffer_mut().len(), super::PACKET_SIZE);
|
||||
}
|
||||
}
|
||||
487
mycelium/src/data.rs
Normal file
487
mycelium/src/data.rs
Normal file
@@ -0,0 +1,487 @@
|
||||
use std::net::{IpAddr, Ipv6Addr};
|
||||
|
||||
use etherparse::{
|
||||
icmpv6::{DestUnreachableCode, TimeExceededCode},
|
||||
Icmpv6Type, PacketBuilder,
|
||||
};
|
||||
use futures::{Sink, SinkExt, Stream, StreamExt};
|
||||
use tokio::sync::mpsc::UnboundedReceiver;
|
||||
use tracing::{debug, error, trace, warn};
|
||||
|
||||
use crate::{crypto::PacketBuffer, metrics::Metrics, packet::DataPacket, router::Router};
|
||||
|
||||
/// Current version of the user data header.
|
||||
const USER_DATA_VERSION: u8 = 1;
|
||||
|
||||
/// Type value indicating L3 data in the user data header.
|
||||
const USER_DATA_L3_TYPE: u8 = 0;
|
||||
|
||||
/// Type value indicating a user message in the data header.
|
||||
const USER_DATA_MESSAGE_TYPE: u8 = 1;
|
||||
|
||||
/// Type value indicating an ICMP packet not returned as regular IPv6 traffic. This is needed when
|
||||
/// intermediate nodes send back icmp data, as the original data is encrypted.
|
||||
const USER_DATA_OOB_ICMP: u8 = 2;
|
||||
|
||||
/// Minimum size in bytes of an IPv6 header.
|
||||
const IPV6_MIN_HEADER_SIZE: usize = 40;
|
||||
|
||||
/// Size of an ICMPv6 header.
|
||||
const ICMP6_HEADER_SIZE: usize = 8;
|
||||
|
||||
/// Minimum MTU for IPV6 according to https://www.rfc-editor.org/rfc/rfc8200#section-5.
|
||||
/// For ICMP, the packet must not be greater than this value. This is specified in
|
||||
/// https://datatracker.ietf.org/doc/html/rfc4443#section-2.4, section (c).
|
||||
const MIN_IPV6_MTU: usize = 1280;
|
||||
|
||||
/// Mask applied to the first byte of an IP header to extract the version.
|
||||
const IP_VERSION_MASK: u8 = 0b1111_0000;
|
||||
|
||||
/// Version byte of an IP header indicating IPv6. Since the version is only 4 bits, the lower bits
|
||||
/// must be masked first.
|
||||
const IPV6_VERSION_BYTE: u8 = 0b0110_0000;
|
||||
|
||||
/// Default hop limit for message packets. For now this is set to 64 hops.
|
||||
///
|
||||
/// For regular l3 packets, we copy the hop limit from the packet itself. We can't do that here, so
|
||||
/// 64 is used as sane default.
|
||||
const MESSAGE_HOP_LIMIT: u8 = 64;
|
||||
|
||||
/// The DataPlane manages forwarding/receiving of local data packets to the [`Router`], and the
|
||||
/// encryption/decryption of them.
|
||||
///
|
||||
/// DataPlane itself can be cloned, but this is not cheap on the router and should be avoided.
|
||||
pub struct DataPlane<M> {
|
||||
router: Router<M>,
|
||||
}
|
||||
|
||||
impl<M> DataPlane<M>
|
||||
where
|
||||
M: Metrics + Clone + Send + 'static,
|
||||
{
|
||||
/// Create a new `DataPlane` using the given [`Router`] for packet handling.
|
||||
///
|
||||
/// `l3_packet_stream` is a stream of l3 packets from the host, usually read from a TUN interface.
|
||||
/// `l3_packet_sink` is a sink for l3 packets received from a romte, usually send to a TUN interface,
|
||||
pub fn new<S, T, U>(
|
||||
router: Router<M>,
|
||||
l3_packet_stream: S,
|
||||
l3_packet_sink: T,
|
||||
message_packet_sink: U,
|
||||
host_packet_source: UnboundedReceiver<DataPacket>,
|
||||
) -> Self
|
||||
where
|
||||
S: Stream<Item = Result<PacketBuffer, std::io::Error>> + Send + Unpin + 'static,
|
||||
T: Sink<PacketBuffer> + Clone + Send + Unpin + 'static,
|
||||
T::Error: std::fmt::Display,
|
||||
U: Sink<(PacketBuffer, IpAddr, IpAddr)> + Send + Unpin + 'static,
|
||||
U::Error: std::fmt::Display,
|
||||
{
|
||||
let dp = Self { router };
|
||||
|
||||
tokio::spawn(
|
||||
dp.clone()
|
||||
.inject_l3_packet_loop(l3_packet_stream, l3_packet_sink.clone()),
|
||||
);
|
||||
tokio::spawn(dp.clone().extract_packet_loop(
|
||||
l3_packet_sink,
|
||||
message_packet_sink,
|
||||
host_packet_source,
|
||||
));
|
||||
|
||||
dp
|
||||
}
|
||||
|
||||
/// Get a reference to the [`Router`] used.
|
||||
pub fn router(&self) -> &Router<M> {
|
||||
&self.router
|
||||
}
|
||||
|
||||
async fn inject_l3_packet_loop<S, T>(self, mut l3_packet_stream: S, mut l3_packet_sink: T)
|
||||
where
|
||||
// TODO: no result
|
||||
// TODO: should IP extraction be handled higher up?
|
||||
S: Stream<Item = Result<PacketBuffer, std::io::Error>> + Send + Unpin + 'static,
|
||||
T: Sink<PacketBuffer> + Clone + Send + Unpin + 'static,
|
||||
T::Error: std::fmt::Display,
|
||||
{
|
||||
let node_subnet = self.router.node_tun_subnet();
|
||||
|
||||
while let Some(packet) = l3_packet_stream.next().await {
|
||||
let mut packet = match packet {
|
||||
Err(e) => {
|
||||
error!("Failed to read packet from TUN interface {e}");
|
||||
continue;
|
||||
}
|
||||
Ok(packet) => packet,
|
||||
};
|
||||
|
||||
trace!("Received packet from tun");
|
||||
|
||||
// Parse an IPv6 header. We don't care about the full header in reality. What we want
|
||||
// to know is:
|
||||
// - This is an IPv6 header
|
||||
// - Hop limit
|
||||
// - Source address
|
||||
// - Destination address
|
||||
// This translates to the following requirements:
|
||||
// - at least 40 bytes of data, as that is the minimum size of an IPv6 header
|
||||
// - first 4 bits (version) are the constant 6 (0b0110)
|
||||
// - src is byte 9-24 (8-23 0 indexed).
|
||||
// - dst is byte 25-40 (24-39 0 indexed).
|
||||
|
||||
if packet.len() < IPV6_MIN_HEADER_SIZE {
|
||||
trace!("Packet can't contain an IPv6 header");
|
||||
continue;
|
||||
}
|
||||
|
||||
if packet[0] & IP_VERSION_MASK != IPV6_VERSION_BYTE {
|
||||
trace!("Packet is not IPv6");
|
||||
continue;
|
||||
}
|
||||
|
||||
let hop_limit = u8::from_be_bytes([packet[7]]);
|
||||
|
||||
let src_ip = Ipv6Addr::from(
|
||||
<&[u8] as TryInto<[u8; 16]>>::try_into(&packet[8..24])
|
||||
.expect("Static range bounds on slice are correct length"),
|
||||
);
|
||||
let dst_ip = Ipv6Addr::from(
|
||||
<&[u8] as TryInto<[u8; 16]>>::try_into(&packet[24..40])
|
||||
.expect("Static range bounds on slice are correct length"),
|
||||
);
|
||||
|
||||
// If this is a packet for our own Subnet, it means there is no local configuration for
|
||||
// the destination ip or /64 subnet, and the IP is unreachable
|
||||
if node_subnet.contains_ip(dst_ip.into()) {
|
||||
trace!(
|
||||
"Replying to local packet for unexisting address: {}",
|
||||
dst_ip
|
||||
);
|
||||
|
||||
let mut icmp_packet = PacketBuffer::new();
|
||||
let host = self.router.node_public_key().address().octets();
|
||||
let icmp = PacketBuilder::ipv6(host, src_ip.octets(), 64).icmpv6(
|
||||
Icmpv6Type::DestinationUnreachable(DestUnreachableCode::Address),
|
||||
);
|
||||
icmp_packet.set_size(icmp.size(packet.len().min(1280 - 48)));
|
||||
let mut writer = &mut icmp_packet.buffer_mut()[..];
|
||||
if let Err(e) = icmp.write(&mut writer, &packet[..packet.len().min(1280 - 48)]) {
|
||||
error!("Failed to construct ICMP packet: {e}");
|
||||
continue;
|
||||
}
|
||||
if let Err(e) = l3_packet_sink.send(icmp_packet).await {
|
||||
error!("Failed to send ICMP packet to host: {e}");
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
trace!("Received packet from TUN with dest addr: {:?}", dst_ip);
|
||||
// Check if the source address is part of 400::/7
|
||||
let first_src_byte = src_ip.segments()[0] >> 8;
|
||||
if !(0x04..0x06).contains(&first_src_byte) {
|
||||
let mut icmp_packet = PacketBuffer::new();
|
||||
let host = self.router.node_public_key().address().octets();
|
||||
let icmp = PacketBuilder::ipv6(host, src_ip.octets(), 64).icmpv6(
|
||||
Icmpv6Type::DestinationUnreachable(
|
||||
DestUnreachableCode::SourceAddressFailedPolicy,
|
||||
),
|
||||
);
|
||||
icmp_packet.set_size(icmp.size(packet.len().min(1280 - 48)));
|
||||
let mut writer = &mut icmp_packet.buffer_mut()[..];
|
||||
if let Err(e) = icmp.write(&mut writer, &packet[..packet.len().min(1280 - 48)]) {
|
||||
error!("Failed to construct ICMP packet: {e}");
|
||||
continue;
|
||||
}
|
||||
if let Err(e) = l3_packet_sink.send(icmp_packet).await {
|
||||
error!("Failed to send ICMP packet to host: {e}");
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// No need to verify destination address, if it is not part of the global subnet there
|
||||
// should not be a route for it, and therefore the route step will generate the
|
||||
// appropriate ICMP.
|
||||
|
||||
let mut header = packet.header_mut();
|
||||
header[0] = USER_DATA_VERSION;
|
||||
header[1] = USER_DATA_L3_TYPE;
|
||||
|
||||
if let Some(icmp) = self.encrypt_and_route_packet(src_ip, dst_ip, hop_limit, packet) {
|
||||
if let Err(e) = l3_packet_sink.send(icmp).await {
|
||||
error!("Could not forward icmp packet back to TUN interface {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
warn!("Data inject loop from host to router ended");
|
||||
}
|
||||
|
||||
/// Inject a new packet where the content is a `message` fragment.
|
||||
pub fn inject_message_packet(
|
||||
&self,
|
||||
src_ip: Ipv6Addr,
|
||||
dst_ip: Ipv6Addr,
|
||||
mut packet: PacketBuffer,
|
||||
) {
|
||||
let mut header = packet.header_mut();
|
||||
header[0] = USER_DATA_VERSION;
|
||||
header[1] = USER_DATA_MESSAGE_TYPE;
|
||||
|
||||
self.encrypt_and_route_packet(src_ip, dst_ip, MESSAGE_HOP_LIMIT, packet);
|
||||
}
|
||||
|
||||
/// Encrypt the content of a packet based on the destination key, and then inject the packet
|
||||
/// into the [`Router`] for processing.
|
||||
///
|
||||
/// If no key exists for the destination, the content can'be encrypted, the packet is not injected
|
||||
/// into the router, and a packet is returned containing an ICMP packet. Note that a return
|
||||
/// value of [`Option::None`] does not mean the packet was successfully forwarded;
|
||||
fn encrypt_and_route_packet(
|
||||
&self,
|
||||
src_ip: Ipv6Addr,
|
||||
dst_ip: Ipv6Addr,
|
||||
hop_limit: u8,
|
||||
packet: PacketBuffer,
|
||||
) -> Option<PacketBuffer> {
|
||||
// If the packet only has a TTL of 1, we won't be able to route it to the destination
|
||||
// regardless, so just reply with an unencrypted TTL exceeded ICMP.
|
||||
if hop_limit < 2 {
|
||||
debug!(
|
||||
packet.ttl = hop_limit,
|
||||
packet.src = %src_ip,
|
||||
packet.dst = %dst_ip,
|
||||
"Attempting to route packet with insufficient TTL",
|
||||
);
|
||||
|
||||
let mut pb = PacketBuffer::new();
|
||||
// From self to self
|
||||
let icmp = PacketBuilder::ipv6(src_ip.octets(), src_ip.octets(), hop_limit)
|
||||
.icmpv6(Icmpv6Type::TimeExceeded(TimeExceededCode::HopLimitExceeded));
|
||||
// Scale to max size if needed
|
||||
let orig_buf_end = packet
|
||||
.buffer()
|
||||
.len()
|
||||
.min(MIN_IPV6_MTU - IPV6_MIN_HEADER_SIZE - ICMP6_HEADER_SIZE);
|
||||
pb.set_size(icmp.size(orig_buf_end));
|
||||
let mut b = pb.buffer_mut();
|
||||
if let Err(e) = icmp.write(&mut b, &packet.buffer()[..orig_buf_end]) {
|
||||
error!("Failed to construct time exceeded ICMP packet {e}");
|
||||
return None;
|
||||
}
|
||||
|
||||
return Some(pb);
|
||||
}
|
||||
|
||||
// Get shared secret from node and dest address
|
||||
let shared_secret = match self.router.get_shared_secret_if_selected(dst_ip.into()) {
|
||||
Some(ss) => ss,
|
||||
// If we don't have a route to the destination subnet, reply with ICMP no route to
|
||||
// host. Do this here as well to avoid encrypting the ICMP to ourselves.
|
||||
None => {
|
||||
debug!(
|
||||
packet.src = %src_ip,
|
||||
packet.dst = %dst_ip,
|
||||
"No entry found for destination address, dropping packet",
|
||||
);
|
||||
|
||||
let mut pb = PacketBuffer::new();
|
||||
// From self to self
|
||||
let icmp = PacketBuilder::ipv6(src_ip.octets(), src_ip.octets(), hop_limit).icmpv6(
|
||||
Icmpv6Type::DestinationUnreachable(DestUnreachableCode::NoRoute),
|
||||
);
|
||||
// Scale to max size if needed
|
||||
let orig_buf_end = packet
|
||||
.buffer()
|
||||
.len()
|
||||
.min(MIN_IPV6_MTU - IPV6_MIN_HEADER_SIZE - ICMP6_HEADER_SIZE);
|
||||
pb.set_size(icmp.size(orig_buf_end));
|
||||
let mut b = pb.buffer_mut();
|
||||
if let Err(e) = icmp.write(&mut b, &packet.buffer()[..orig_buf_end]) {
|
||||
error!("Failed to construct no route to host ICMP packet {e}");
|
||||
return None;
|
||||
}
|
||||
|
||||
return Some(pb);
|
||||
}
|
||||
};
|
||||
|
||||
self.router.route_packet(DataPacket {
|
||||
dst_ip,
|
||||
src_ip,
|
||||
hop_limit,
|
||||
raw_data: shared_secret.encrypt(packet),
|
||||
});
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
async fn extract_packet_loop<T, U>(
|
||||
self,
|
||||
mut l3_packet_sink: T,
|
||||
mut message_packet_sink: U,
|
||||
mut host_packet_source: UnboundedReceiver<DataPacket>,
|
||||
) where
|
||||
T: Sink<PacketBuffer> + Send + Unpin + 'static,
|
||||
T::Error: std::fmt::Display,
|
||||
U: Sink<(PacketBuffer, IpAddr, IpAddr)> + Send + Unpin + 'static,
|
||||
U::Error: std::fmt::Display,
|
||||
{
|
||||
while let Some(data_packet) = host_packet_source.recv().await {
|
||||
// decrypt & send to TUN interface
|
||||
let shared_secret = if let Some(ss) = self
|
||||
.router
|
||||
.get_shared_secret_from_dest(data_packet.src_ip.into())
|
||||
{
|
||||
ss
|
||||
} else {
|
||||
trace!("Received packet from unknown sender");
|
||||
continue;
|
||||
};
|
||||
let mut decrypted_packet = match shared_secret.decrypt(data_packet.raw_data) {
|
||||
Ok(data) => data,
|
||||
Err(_) => {
|
||||
debug!("Dropping data packet with invalid encrypted content");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Check header
|
||||
let header = decrypted_packet.header();
|
||||
if header[0] != USER_DATA_VERSION {
|
||||
trace!("Dropping decrypted packet with unknown header version");
|
||||
continue;
|
||||
}
|
||||
|
||||
// Route based on packet type.
|
||||
match header[1] {
|
||||
USER_DATA_L3_TYPE => {
|
||||
let real_packet = decrypted_packet.buffer_mut();
|
||||
if real_packet.len() < IPV6_MIN_HEADER_SIZE {
|
||||
debug!(
|
||||
"Decrypted packet is too short, can't possibly be a valid IPv6 packet"
|
||||
);
|
||||
continue;
|
||||
}
|
||||
// Adjust the hop limit in the decrypted packet to the new value.
|
||||
real_packet[7] = data_packet.hop_limit;
|
||||
if let Err(e) = l3_packet_sink.send(decrypted_packet).await {
|
||||
error!("Failed to send packet on local TUN interface: {e}",);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
USER_DATA_MESSAGE_TYPE => {
|
||||
if let Err(e) = message_packet_sink
|
||||
.send((
|
||||
decrypted_packet,
|
||||
IpAddr::V6(data_packet.src_ip),
|
||||
IpAddr::V6(data_packet.dst_ip),
|
||||
))
|
||||
.await
|
||||
{
|
||||
error!("Failed to send packet to message handler: {e}",);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
USER_DATA_OOB_ICMP => {
|
||||
let real_packet = &*decrypted_packet;
|
||||
if real_packet.len() < IPV6_MIN_HEADER_SIZE + ICMP6_HEADER_SIZE + 16 {
|
||||
debug!(
|
||||
"Decrypted packet is too short, can't possibly be a valid IPv6 ICMP packet"
|
||||
);
|
||||
continue;
|
||||
}
|
||||
if real_packet.len() > MIN_IPV6_MTU + 16 {
|
||||
debug!("Discarding ICMP packet which is too large");
|
||||
continue;
|
||||
}
|
||||
|
||||
let dec_ip = Ipv6Addr::from(
|
||||
<&[u8] as TryInto<[u8; 16]>>::try_into(&real_packet[..16]).unwrap(),
|
||||
);
|
||||
trace!("ICMP for original target {dec_ip}");
|
||||
|
||||
let key =
|
||||
if let Some(key) = self.router.get_shared_secret_from_dest(dec_ip.into()) {
|
||||
key
|
||||
} else {
|
||||
debug!("Can't decrypt OOB ICMP packet from unknown host");
|
||||
continue;
|
||||
};
|
||||
|
||||
let (_, body) = match etherparse::IpHeaders::from_slice(&real_packet[16..]) {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
// This is a node which does not adhere to the protocol of sending back
|
||||
// ICMP like this, or it is intentionally sending mallicious packets.
|
||||
debug!(
|
||||
"Dropping malformed OOB ICMP packet from {} for {e}",
|
||||
data_packet.src_ip
|
||||
);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let (header, body) = match etherparse::Icmpv6Header::from_slice(body.payload) {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
// This is a node which does not adhere to the protocol of sending back
|
||||
// ICMP like this, or it is intentionally sending mallicious packets.
|
||||
debug!(
|
||||
"Dropping OOB ICMP packet from {} with malformed ICMP header ({e})",
|
||||
data_packet.src_ip
|
||||
);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Where are the leftover bytes coming from
|
||||
let orig_pb = match key.decrypt(body[..body.len()].to_vec()) {
|
||||
Ok(pb) => pb,
|
||||
Err(e) => {
|
||||
warn!("Failed to decrypt ICMP data body {e}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let packet = etherparse::PacketBuilder::ipv6(
|
||||
data_packet.src_ip.octets(),
|
||||
data_packet.dst_ip.octets(),
|
||||
data_packet.hop_limit,
|
||||
)
|
||||
.icmpv6(header.icmp_type);
|
||||
|
||||
let serialized_icmp = packet.size(orig_pb.len());
|
||||
let mut rp = PacketBuffer::new();
|
||||
rp.set_size(serialized_icmp);
|
||||
if let Err(e) =
|
||||
packet.write(&mut (&mut rp.buffer_mut()[..serialized_icmp]), &orig_pb)
|
||||
{
|
||||
error!("Could not reconstruct icmp packet {e}");
|
||||
continue;
|
||||
}
|
||||
if let Err(e) = l3_packet_sink.send(rp).await {
|
||||
error!("Failed to send packet on local TUN interface: {e}",);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
trace!("Dropping decrypted packet with unknown protocol type");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
warn!("Extract loop from router to host ended");
|
||||
}
|
||||
}
|
||||
|
||||
impl<M> Clone for DataPlane<M>
|
||||
where
|
||||
M: Clone,
|
||||
{
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
router: self.router.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
116
mycelium/src/endpoint.rs
Normal file
116
mycelium/src/endpoint.rs
Normal file
@@ -0,0 +1,116 @@
|
||||
use std::{
|
||||
fmt,
|
||||
net::{AddrParseError, SocketAddr},
|
||||
str::FromStr,
|
||||
};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
/// Error generated while processing improperly formatted endpoints.
|
||||
pub enum EndpointParseError {
|
||||
/// An address was specified without leading protocol information.
|
||||
MissingProtocol,
|
||||
/// An endpoint was specified using a protocol we (currently) do not understand.
|
||||
UnknownProtocol,
|
||||
/// Error while parsing the specific address.
|
||||
Address(AddrParseError),
|
||||
}
|
||||
|
||||
/// Protocol used by an endpoint.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub enum Protocol {
|
||||
/// Standard plain text Tcp.
|
||||
Tcp,
|
||||
/// Tls 1.3 with PSK over Tcp.
|
||||
Tls,
|
||||
/// Quic protocol (over UDP).
|
||||
Quic,
|
||||
}
|
||||
|
||||
/// An endpoint defines a address and a protocol to use when communicating with it.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Endpoint {
|
||||
proto: Protocol,
|
||||
socket_addr: SocketAddr,
|
||||
}
|
||||
|
||||
impl Endpoint {
|
||||
/// Create a new `Endpoint` with given [`Protocol`] and address.
|
||||
pub fn new(proto: Protocol, socket_addr: SocketAddr) -> Self {
|
||||
Self { proto, socket_addr }
|
||||
}
|
||||
|
||||
/// Get the [`Protocol`] used by this `Endpoint`.
|
||||
pub fn proto(&self) -> Protocol {
|
||||
self.proto
|
||||
}
|
||||
|
||||
/// Get the [`SocketAddr`] used by this `Endpoint`.
|
||||
pub fn address(&self) -> SocketAddr {
|
||||
self.socket_addr
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for Endpoint {
|
||||
type Err = EndpointParseError;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s.split_once("://") {
|
||||
None => Err(EndpointParseError::MissingProtocol),
|
||||
Some((proto, socket)) => {
|
||||
let proto = match proto.to_lowercase().as_str() {
|
||||
"tcp" => Protocol::Tcp,
|
||||
"quic" => Protocol::Quic,
|
||||
"tls" => Protocol::Tls,
|
||||
_ => return Err(EndpointParseError::UnknownProtocol),
|
||||
};
|
||||
let socket_addr = SocketAddr::from_str(socket)?;
|
||||
Ok(Endpoint { proto, socket_addr })
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for Endpoint {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.write_fmt(format_args!("{} {}", self.proto, self.socket_addr))
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for Protocol {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.write_str(match self {
|
||||
Self::Tcp => "Tcp",
|
||||
Self::Tls => "Tls",
|
||||
Self::Quic => "Quic",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for EndpointParseError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Self::MissingProtocol => f.write_str("missing leading protocol identifier"),
|
||||
Self::UnknownProtocol => f.write_str("protocol for endpoint is not supported"),
|
||||
Self::Address(e) => f.write_fmt(format_args!("failed to parse address: {e}")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for EndpointParseError {
|
||||
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
||||
match self {
|
||||
Self::Address(e) => Some(e),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<AddrParseError> for EndpointParseError {
|
||||
fn from(value: AddrParseError) -> Self {
|
||||
Self::Address(value)
|
||||
}
|
||||
}
|
||||
53
mycelium/src/filters.rs
Normal file
53
mycelium/src/filters.rs
Normal file
@@ -0,0 +1,53 @@
|
||||
use crate::{babel, subnet::Subnet};
|
||||
|
||||
/// This trait is used to filter incoming updates from peers. Only updates which pass all
|
||||
/// configured filters on the local [`Router`](crate::router::Router) will actually be forwarded
|
||||
/// to the [`Router`](crate::router::Router) for processing.
|
||||
pub trait RouteUpdateFilter {
|
||||
/// Judge an incoming update.
|
||||
fn allow(&self, update: &babel::Update) -> bool;
|
||||
}
|
||||
|
||||
/// Limit the subnet size of subnets announced in updates to be at most `N` bits. Note that "at
|
||||
/// most" here means that the actual prefix length needs to be **AT LEAST** this value.
|
||||
pub struct MaxSubnetSize<const N: u8>;
|
||||
|
||||
impl<const N: u8> RouteUpdateFilter for MaxSubnetSize<N> {
|
||||
fn allow(&self, update: &babel::Update) -> bool {
|
||||
update.subnet().prefix_len() >= N
|
||||
}
|
||||
}
|
||||
|
||||
/// Limit the subnet announced to be included in the given subnet.
|
||||
pub struct AllowedSubnet {
|
||||
subnet: Subnet,
|
||||
}
|
||||
|
||||
impl AllowedSubnet {
|
||||
/// Create a new `AllowedSubnet` filter, which only allows updates who's `Subnet` is contained
|
||||
/// in the given `Subnet`.
|
||||
pub fn new(subnet: Subnet) -> Self {
|
||||
Self { subnet }
|
||||
}
|
||||
}
|
||||
|
||||
impl RouteUpdateFilter for AllowedSubnet {
|
||||
fn allow(&self, update: &babel::Update) -> bool {
|
||||
self.subnet.contains_subnet(&update.subnet())
|
||||
}
|
||||
}
|
||||
|
||||
/// Limit the announced subnets to those which contain the derived IP from the `RouterId`.
|
||||
///
|
||||
/// Since retractions can be sent by any node to indicate they don't have a route for the subnet,
|
||||
/// these are also allowed.
|
||||
pub struct RouterIdOwnsSubnet;
|
||||
|
||||
impl RouteUpdateFilter for RouterIdOwnsSubnet {
|
||||
fn allow(&self, update: &babel::Update) -> bool {
|
||||
update.metric().is_infinite()
|
||||
|| update
|
||||
.subnet()
|
||||
.contains_ip(update.router_id().to_pubkey().address().into())
|
||||
}
|
||||
}
|
||||
38
mycelium/src/interval.rs
Normal file
38
mycelium/src/interval.rs
Normal file
@@ -0,0 +1,38 @@
|
||||
//! Dedicated logic for
|
||||
//! [intervals](https://datatracker.ietf.org/doc/html/rfc8966#name-solving-starvation-sequenci).
|
||||
|
||||
use std::time::Duration;
|
||||
|
||||
/// An interval in the babel protocol.
|
||||
///
|
||||
/// Intervals represent a duration, and are expressed in centiseconds (0.01 second / 10
|
||||
/// milliseconds). `Interval` implements [`From`] [`u16`] to create a new interval from a raw
|
||||
/// value, and [`From`] [`Duration`] to create a new `Interval` from an existing [`Duration`].
|
||||
/// There are also implementation to convert back to the aforementioned types. Note that in case of
|
||||
/// duration, millisecond precision is lost.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Interval(u16);
|
||||
|
||||
impl From<Duration> for Interval {
|
||||
fn from(value: Duration) -> Self {
|
||||
Interval((value.as_millis() / 10) as u16)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Interval> for Duration {
|
||||
fn from(value: Interval) -> Self {
|
||||
Duration::from_millis(value.0 as u64 * 10)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<u16> for Interval {
|
||||
fn from(value: u16) -> Self {
|
||||
Interval(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Interval> for u16 {
|
||||
fn from(value: Interval) -> Self {
|
||||
value.0
|
||||
}
|
||||
}
|
||||
462
mycelium/src/lib.rs
Normal file
462
mycelium/src/lib.rs
Normal file
@@ -0,0 +1,462 @@
|
||||
use std::net::{IpAddr, Ipv6Addr};
|
||||
use std::path::PathBuf;
|
||||
#[cfg(feature = "message")]
|
||||
use std::{future::Future, time::Duration};
|
||||
|
||||
use crate::cdn::Cdn;
|
||||
use crate::tun::TunConfig;
|
||||
use bytes::BytesMut;
|
||||
use data::DataPlane;
|
||||
use endpoint::Endpoint;
|
||||
#[cfg(feature = "message")]
|
||||
use message::TopicConfig;
|
||||
#[cfg(feature = "message")]
|
||||
use message::{
|
||||
MessageId, MessageInfo, MessagePushResponse, MessageStack, PushMessageError, ReceivedMessage,
|
||||
};
|
||||
use metrics::Metrics;
|
||||
use peer_manager::{PeerExists, PeerNotFound, PeerStats, PrivateNetworkKey};
|
||||
use routing_table::{NoRouteSubnet, QueriedSubnet, RouteEntry};
|
||||
use subnet::Subnet;
|
||||
use tokio::net::TcpListener;
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
mod babel;
|
||||
pub mod cdn;
|
||||
mod connection;
|
||||
pub mod crypto;
|
||||
pub mod data;
|
||||
pub mod endpoint;
|
||||
pub mod filters;
|
||||
mod interval;
|
||||
#[cfg(feature = "message")]
|
||||
pub mod message;
|
||||
mod metric;
|
||||
pub mod metrics;
|
||||
pub mod packet;
|
||||
mod peer;
|
||||
pub mod peer_manager;
|
||||
pub mod router;
|
||||
mod router_id;
|
||||
mod routing_table;
|
||||
mod rr_cache;
|
||||
mod seqno_cache;
|
||||
mod sequence_number;
|
||||
mod source_table;
|
||||
pub mod subnet;
|
||||
pub mod task;
|
||||
mod tun;
|
||||
|
||||
/// The prefix of the global subnet used.
|
||||
pub const GLOBAL_SUBNET_ADDRESS: IpAddr = IpAddr::V6(Ipv6Addr::new(0x400, 0, 0, 0, 0, 0, 0, 0));
|
||||
/// The prefix length of the global subnet used.
|
||||
pub const GLOBAL_SUBNET_PREFIX_LEN: u8 = 7;
|
||||
|
||||
/// Config for a mycelium [`Node`].
|
||||
pub struct Config<M> {
|
||||
/// The secret key of the node.
|
||||
pub node_key: crypto::SecretKey,
|
||||
/// Statically configured peers.
|
||||
pub peers: Vec<Endpoint>,
|
||||
/// Tun interface should be disabled.
|
||||
pub no_tun: bool,
|
||||
/// Listen port for TCP connections.
|
||||
pub tcp_listen_port: u16,
|
||||
/// Listen port for Quic connections.
|
||||
pub quic_listen_port: Option<u16>,
|
||||
/// Udp port for peer discovery.
|
||||
pub peer_discovery_port: Option<u16>,
|
||||
/// Name for the TUN device.
|
||||
#[cfg(any(
|
||||
target_os = "linux",
|
||||
all(target_os = "macos", not(feature = "mactunfd")),
|
||||
target_os = "windows"
|
||||
))]
|
||||
pub tun_name: String,
|
||||
|
||||
/// Configuration for a private network, if run in that mode. To enable private networking,
|
||||
/// this must be a name + a PSK.
|
||||
pub private_network_config: Option<(String, PrivateNetworkKey)>,
|
||||
/// Implementation of the `Metrics` trait, used to expose information about the system
|
||||
/// internals.
|
||||
pub metrics: M,
|
||||
/// Mark that's set on all packets that we send on the underlying network
|
||||
pub firewall_mark: Option<u32>,
|
||||
|
||||
// tun_fd is android, iOS, macos on appstore specific option
|
||||
// We can't create TUN device from the Rust code in android, iOS, and macos on appstore.
|
||||
// So, we create the TUN device on Kotlin(android) or Swift(iOS, macos) then pass
|
||||
// the TUN's file descriptor to mycelium.
|
||||
#[cfg(any(
|
||||
target_os = "android",
|
||||
target_os = "ios",
|
||||
all(target_os = "macos", feature = "mactunfd"),
|
||||
))]
|
||||
pub tun_fd: Option<i32>,
|
||||
|
||||
/// The maount of worker tasks spawned to process updates. Up to this amound of updates can be
|
||||
/// processed in parallel. Because processing an update is a CPU bound task, it is pointless to
|
||||
/// set this to a value which is higher than the amount of logical CPU cores available to the
|
||||
/// system.
|
||||
pub update_workers: usize,
|
||||
|
||||
pub cdn_cache: Option<PathBuf>,
|
||||
|
||||
/// Configuration for message topics, if this is not set the default config will be used.
|
||||
#[cfg(feature = "message")]
|
||||
pub topic_config: Option<TopicConfig>,
|
||||
}
|
||||
|
||||
/// The Node is the main structure in mycelium. It governs the entire data flow.
|
||||
pub struct Node<M> {
|
||||
router: router::Router<M>,
|
||||
peer_manager: peer_manager::PeerManager<M>,
|
||||
_cdn: Option<Cdn>,
|
||||
#[cfg(feature = "message")]
|
||||
message_stack: message::MessageStack<M>,
|
||||
}
|
||||
|
||||
/// General info about a node.
|
||||
pub struct NodeInfo {
|
||||
/// The overlay subnet in use by the node.
|
||||
pub node_subnet: Subnet,
|
||||
/// The public key of the node
|
||||
pub node_pubkey: crypto::PublicKey,
|
||||
}
|
||||
|
||||
impl<M> Node<M>
|
||||
where
|
||||
M: Metrics + Clone + Send + Sync + 'static,
|
||||
{
|
||||
/// Setup a new `Node` with the provided [`Config`].
|
||||
pub async fn new(config: Config<M>) -> Result<Self, Box<dyn std::error::Error>> {
|
||||
// If a private network is configured, validate network name
|
||||
if let Some((net_name, _)) = &config.private_network_config {
|
||||
if net_name.len() < 2 || net_name.len() > 64 {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidInput,
|
||||
"network name must be between 2 and 64 characters",
|
||||
)
|
||||
.into());
|
||||
}
|
||||
}
|
||||
let node_pub_key = crypto::PublicKey::from(&config.node_key);
|
||||
let node_addr = node_pub_key.address();
|
||||
let (tun_tx, tun_rx) = tokio::sync::mpsc::unbounded_channel();
|
||||
|
||||
let node_subnet = Subnet::new(
|
||||
// Truncate last 64 bits of address.
|
||||
// TODO: find a better way to do this.
|
||||
Subnet::new(node_addr.into(), 64)
|
||||
.expect("64 is a valid IPv6 prefix size; qed")
|
||||
.network(),
|
||||
64,
|
||||
)
|
||||
.expect("64 is a valid IPv6 prefix size; qed");
|
||||
|
||||
// Creating a new Router instance
|
||||
let router = match router::Router::new(
|
||||
config.update_workers,
|
||||
tun_tx,
|
||||
node_subnet,
|
||||
vec![node_subnet],
|
||||
(config.node_key, node_pub_key),
|
||||
vec![
|
||||
Box::new(filters::AllowedSubnet::new(
|
||||
Subnet::new(GLOBAL_SUBNET_ADDRESS, GLOBAL_SUBNET_PREFIX_LEN)
|
||||
.expect("Global subnet is properly defined; qed"),
|
||||
)),
|
||||
Box::new(filters::MaxSubnetSize::<64>),
|
||||
Box::new(filters::RouterIdOwnsSubnet),
|
||||
],
|
||||
config.metrics.clone(),
|
||||
) {
|
||||
Ok(router) => {
|
||||
info!(
|
||||
"Router created. Pubkey: {:x}",
|
||||
BytesMut::from(&router.node_public_key().as_bytes()[..])
|
||||
);
|
||||
router
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Error creating router: {e}");
|
||||
panic!("Error creating router: {e}");
|
||||
}
|
||||
};
|
||||
|
||||
// Creating a new PeerManager instance
|
||||
let pm = peer_manager::PeerManager::new(
|
||||
router.clone(),
|
||||
config.peers,
|
||||
config.tcp_listen_port,
|
||||
config.quic_listen_port,
|
||||
config.peer_discovery_port.unwrap_or_default(),
|
||||
config.peer_discovery_port.is_none(),
|
||||
config.private_network_config,
|
||||
config.metrics,
|
||||
config.firewall_mark,
|
||||
)?;
|
||||
info!("Started peer manager");
|
||||
|
||||
#[cfg(feature = "message")]
|
||||
let (tx, rx) = tokio::sync::mpsc::channel(100);
|
||||
#[cfg(feature = "message")]
|
||||
let msg_receiver = tokio_stream::wrappers::ReceiverStream::new(rx);
|
||||
#[cfg(feature = "message")]
|
||||
let msg_sender = tokio_util::sync::PollSender::new(tx);
|
||||
#[cfg(not(feature = "message"))]
|
||||
let msg_sender = futures::sink::drain();
|
||||
|
||||
let _data_plane = if config.no_tun {
|
||||
warn!("Starting data plane without TUN interface, L3 functionality disabled");
|
||||
DataPlane::new(
|
||||
router.clone(),
|
||||
// No tun so create a dummy stream for L3 packets which never yields
|
||||
tokio_stream::pending(),
|
||||
// Similarly, create a sink which just discards every packet we would receive
|
||||
futures::sink::drain(),
|
||||
msg_sender,
|
||||
tun_rx,
|
||||
)
|
||||
} else {
|
||||
#[cfg(not(any(
|
||||
target_os = "linux",
|
||||
target_os = "macos",
|
||||
target_os = "windows",
|
||||
target_os = "android",
|
||||
target_os = "ios"
|
||||
)))]
|
||||
{
|
||||
panic!("On this platform, you can only run with --no-tun");
|
||||
}
|
||||
#[cfg(any(
|
||||
target_os = "linux",
|
||||
target_os = "macos",
|
||||
target_os = "windows",
|
||||
target_os = "android",
|
||||
target_os = "ios"
|
||||
))]
|
||||
{
|
||||
#[cfg(any(
|
||||
target_os = "linux",
|
||||
all(target_os = "macos", not(feature = "mactunfd")),
|
||||
target_os = "windows"
|
||||
))]
|
||||
let tun_config = TunConfig {
|
||||
name: config.tun_name.clone(),
|
||||
node_subnet: Subnet::new(node_addr.into(), 64)
|
||||
.expect("64 is a valid subnet size for IPv6; qed"),
|
||||
route_subnet: Subnet::new(GLOBAL_SUBNET_ADDRESS, GLOBAL_SUBNET_PREFIX_LEN)
|
||||
.expect("Static configured TUN route is valid; qed"),
|
||||
};
|
||||
#[cfg(any(
|
||||
target_os = "android",
|
||||
target_os = "ios",
|
||||
all(target_os = "macos", feature = "mactunfd"),
|
||||
))]
|
||||
let tun_config = TunConfig {
|
||||
tun_fd: config.tun_fd.unwrap(),
|
||||
};
|
||||
|
||||
let (rxhalf, txhalf) = tun::new(tun_config).await?;
|
||||
|
||||
info!("Node overlay IP: {node_addr}");
|
||||
DataPlane::new(router.clone(), rxhalf, txhalf, msg_sender, tun_rx)
|
||||
}
|
||||
};
|
||||
|
||||
let cdn = config.cdn_cache.map(Cdn::new);
|
||||
if let Some(ref cdn) = cdn {
|
||||
let listener = TcpListener::bind("localhost:80").await?;
|
||||
cdn.start(listener)?;
|
||||
}
|
||||
|
||||
#[cfg(feature = "message")]
|
||||
let ms = MessageStack::new(_data_plane, msg_receiver, config.topic_config);
|
||||
|
||||
Ok(Node {
|
||||
router,
|
||||
peer_manager: pm,
|
||||
_cdn: cdn,
|
||||
#[cfg(feature = "message")]
|
||||
message_stack: ms,
|
||||
})
|
||||
}
|
||||
|
||||
/// Get information about the running `Node`
|
||||
pub fn info(&self) -> NodeInfo {
|
||||
NodeInfo {
|
||||
node_subnet: self.router.node_tun_subnet(),
|
||||
node_pubkey: self.router.node_public_key(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get information about the current peers in the `Node`
|
||||
pub fn peer_info(&self) -> Vec<PeerStats> {
|
||||
self.peer_manager.peers()
|
||||
}
|
||||
|
||||
/// Add a new peer to the system identified by an [`Endpoint`].
|
||||
pub fn add_peer(&self, endpoint: Endpoint) -> Result<(), PeerExists> {
|
||||
self.peer_manager.add_peer(endpoint)
|
||||
}
|
||||
|
||||
/// Remove an existing peer identified by an [`Endpoint`] from the system.
|
||||
pub fn remove_peer(&self, endpoint: Endpoint) -> Result<(), PeerNotFound> {
|
||||
self.peer_manager.delete_peer(&endpoint)
|
||||
}
|
||||
|
||||
/// List all selected [`routes`](RouteEntry) in the system.
|
||||
pub fn selected_routes(&self) -> Vec<RouteEntry> {
|
||||
self.router.load_selected_routes()
|
||||
}
|
||||
|
||||
/// List all fallback [`routes`](RouteEntry) in the system.
|
||||
pub fn fallback_routes(&self) -> Vec<RouteEntry> {
|
||||
self.router.load_fallback_routes()
|
||||
}
|
||||
|
||||
/// List all [`queried subnets`](QueriedSubnet) in the system.
|
||||
pub fn queried_subnets(&self) -> Vec<QueriedSubnet> {
|
||||
self.router.load_queried_subnets()
|
||||
}
|
||||
|
||||
/// List all [`subnets with no route`](NoRouteSubnet) in the system.
|
||||
pub fn no_route_entries(&self) -> Vec<NoRouteSubnet> {
|
||||
self.router.load_no_route_entries()
|
||||
}
|
||||
|
||||
/// Get public key from the IP of `Node`
|
||||
pub fn get_pubkey_from_ip(&self, ip: IpAddr) -> Option<crypto::PublicKey> {
|
||||
self.router.get_pubkey(ip)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "message")]
|
||||
impl<M> Node<M>
|
||||
where
|
||||
M: Metrics + Clone + Send + 'static,
|
||||
{
|
||||
/// Wait for a messsage to arrive in the message stack.
|
||||
///
|
||||
/// An the optional `topic` is provided, only messages which have exactly the same value in
|
||||
/// `topic` will be returned. The `pop` argument decides if the message is removed from the
|
||||
/// internal queue or not. If `pop` is `false`, the same message will be returned on the next
|
||||
/// call (with the same topic).
|
||||
///
|
||||
/// This method returns a future which will wait indefinitely until a message is received. It
|
||||
/// is generally a good idea to put a limit on how long to wait by wrapping this in a [`tokio::time::timeout`].
|
||||
pub fn get_message(
|
||||
&self,
|
||||
pop: bool,
|
||||
topic: Option<Vec<u8>>,
|
||||
) -> impl Future<Output = ReceivedMessage> + '_ {
|
||||
// First reborrow only the message stack from self, then manually construct a future. This
|
||||
// avoids a lifetime issue on the router, which is not sync. If a regular 'async' fn would
|
||||
// be used here, we can't specify that at this point sadly.
|
||||
let ms = &self.message_stack;
|
||||
async move { ms.message(pop, topic).await }
|
||||
}
|
||||
|
||||
/// Push a new message to the message stack.
|
||||
///
|
||||
/// The system will attempt to transmit the message for `try_duration`. A message is considered
|
||||
/// transmitted when the receiver has indicated it completely received the message. If
|
||||
/// `subscribe_reply` is `true`, the second return value will be [`Option::Some`], with a
|
||||
/// watcher which will resolve if a reply for this exact message comes in. Since this relies on
|
||||
/// the receiver actually sending a reply, ther is no guarantee that this will eventually
|
||||
/// resolve.
|
||||
pub fn push_message(
|
||||
&self,
|
||||
dst: IpAddr,
|
||||
data: Vec<u8>,
|
||||
topic: Option<Vec<u8>>,
|
||||
try_duration: Duration,
|
||||
subscribe_reply: bool,
|
||||
) -> Result<MessagePushResponse, PushMessageError> {
|
||||
self.message_stack.new_message(
|
||||
dst,
|
||||
data,
|
||||
topic.unwrap_or_default(),
|
||||
try_duration,
|
||||
subscribe_reply,
|
||||
)
|
||||
}
|
||||
|
||||
/// Get the status of a message sent previously.
|
||||
///
|
||||
/// Returns [`Option::None`] if no message is found with the given id. Message info is only
|
||||
/// retained for a limited time after a message has been received, or after the message has
|
||||
/// been aborted due to a timeout.
|
||||
pub fn message_status(&self, id: MessageId) -> Option<MessageInfo> {
|
||||
self.message_stack.message_info(id)
|
||||
}
|
||||
|
||||
/// Send a reply to a previously received message.
|
||||
pub fn reply_message(
|
||||
&self,
|
||||
id: MessageId,
|
||||
dst: IpAddr,
|
||||
data: Vec<u8>,
|
||||
try_duration: Duration,
|
||||
) -> MessageId {
|
||||
self.message_stack
|
||||
.reply_message(id, dst, data, try_duration)
|
||||
}
|
||||
|
||||
/// Get a list of all configured topics
|
||||
pub fn topics(&self) -> Vec<Vec<u8>> {
|
||||
self.message_stack.topics()
|
||||
}
|
||||
|
||||
pub fn topic_allowed_sources(&self, topic: &Vec<u8>) -> Option<Vec<Subnet>> {
|
||||
self.message_stack.topic_allowed_sources(topic)
|
||||
}
|
||||
|
||||
/// Sets the default topic action to accept or reject. This decides how topics which don't have
|
||||
/// an explicit whitelist get handled.
|
||||
pub fn accept_unconfigured_topic(&self, accept: bool) {
|
||||
self.message_stack.set_default_topic_action(accept)
|
||||
}
|
||||
|
||||
/// Whether a topic without default configuration is accepted or not.
|
||||
pub fn unconfigure_topic_action(&self) -> bool {
|
||||
self.message_stack.get_default_topic_action()
|
||||
}
|
||||
|
||||
/// Add a topic to the whitelist without any configured allowed sources.
|
||||
pub fn add_topic_whitelist(&self, topic: Vec<u8>) {
|
||||
self.message_stack.add_topic_whitelist(topic)
|
||||
}
|
||||
|
||||
/// Remove a topic from the whitelist. Future messages will follow the default action.
|
||||
pub fn remove_topic_whitelist(&self, topic: Vec<u8>) {
|
||||
self.message_stack.remove_topic_whitelist(topic)
|
||||
}
|
||||
|
||||
/// Add a new whitelisted source for a topic. This creates the topic if it does not exist yet.
|
||||
pub fn add_topic_whitelist_src(&self, topic: Vec<u8>, src: Subnet) {
|
||||
self.message_stack.add_topic_whitelist_src(topic, src)
|
||||
}
|
||||
|
||||
/// Remove a whitelisted source for a topic.
|
||||
pub fn remove_topic_whitelist_src(&self, topic: Vec<u8>, src: Subnet) {
|
||||
self.message_stack.remove_topic_whitelist_src(topic, src)
|
||||
}
|
||||
|
||||
/// Set the forward socket for a topic. Creates the topic if it doesn't exist.
|
||||
pub fn set_topic_forward_socket(&self, topic: Vec<u8>, socket_path: std::path::PathBuf) {
|
||||
self.message_stack
|
||||
.set_topic_forward_socket(topic, Some(socket_path))
|
||||
}
|
||||
|
||||
/// Get the forward socket for a topic, if any.
|
||||
pub fn get_topic_forward_socket(&self, topic: &Vec<u8>) -> Option<std::path::PathBuf> {
|
||||
self.message_stack.get_topic_forward_socket(topic)
|
||||
}
|
||||
|
||||
/// Removes the forward socket for the topic, if one exists
|
||||
pub fn delete_topic_forward_socket(&self, topic: Vec<u8>) {
|
||||
self.message_stack.set_topic_forward_socket(topic, None)
|
||||
}
|
||||
}
|
||||
1867
mycelium/src/message.rs
Normal file
1867
mycelium/src/message.rs
Normal file
File diff suppressed because it is too large
Load Diff
254
mycelium/src/message/chunk.rs
Normal file
254
mycelium/src/message/chunk.rs
Normal file
@@ -0,0 +1,254 @@
|
||||
use std::fmt;
|
||||
|
||||
use super::MessagePacket;
|
||||
|
||||
/// A message representing a "chunk" message.
|
||||
///
|
||||
/// The body of a chunk message has the following structure:
|
||||
/// - 8 bytes: chunk index
|
||||
/// - 8 bytes: chunk offset
|
||||
/// - 8 bytes: chunk size
|
||||
/// - remainder: chunk data of length based on field 3
|
||||
pub struct MessageChunk {
|
||||
buffer: MessagePacket,
|
||||
}
|
||||
|
||||
impl MessageChunk {
|
||||
/// Create a new `MessageChunk` in the provided [`MessagePacket`].
|
||||
pub fn new(mut buffer: MessagePacket) -> Self {
|
||||
buffer.set_used_buffer_size(24);
|
||||
buffer.header_mut().flags_mut().set_chunk();
|
||||
Self { buffer }
|
||||
}
|
||||
|
||||
/// Return the index of the chunk in the message, as written in the body.
|
||||
pub fn chunk_idx(&self) -> u64 {
|
||||
u64::from_be_bytes(
|
||||
self.buffer.buffer()[..8]
|
||||
.try_into()
|
||||
.expect("Buffer contains a size field of valid length; qed"),
|
||||
)
|
||||
}
|
||||
|
||||
/// Set the index of the chunk in the message body.
|
||||
pub fn set_chunk_idx(&mut self, chunk_idx: u64) {
|
||||
self.buffer.buffer_mut()[..8].copy_from_slice(&chunk_idx.to_be_bytes())
|
||||
}
|
||||
|
||||
/// Return the chunk offset in the message, as written in the body.
|
||||
pub fn chunk_offset(&self) -> u64 {
|
||||
u64::from_be_bytes(
|
||||
self.buffer.buffer()[8..16]
|
||||
.try_into()
|
||||
.expect("Buffer contains a size field of valid length; qed"),
|
||||
)
|
||||
}
|
||||
|
||||
/// Set the offset of the chunk in the message body.
|
||||
pub fn set_chunk_offset(&mut self, chunk_offset: u64) {
|
||||
self.buffer.buffer_mut()[8..16].copy_from_slice(&chunk_offset.to_be_bytes())
|
||||
}
|
||||
|
||||
/// Return the size of the chunk in the message, as written in the body.
|
||||
pub fn chunk_size(&self) -> u64 {
|
||||
// Shield against a corrupt value.
|
||||
u64::min(
|
||||
u64::from_be_bytes(
|
||||
self.buffer.buffer()[16..24]
|
||||
.try_into()
|
||||
.expect("Buffer contains a size field of valid length; qed"),
|
||||
),
|
||||
self.buffer.buffer().len() as u64 - 24,
|
||||
)
|
||||
}
|
||||
|
||||
/// Set the size of the chunk in the message body.
|
||||
pub fn set_chunk_size(&mut self, chunk_size: u64) {
|
||||
self.buffer.buffer_mut()[16..24].copy_from_slice(&chunk_size.to_be_bytes())
|
||||
}
|
||||
|
||||
/// Return a reference to the chunk data in the message.
|
||||
pub fn data(&self) -> &[u8] {
|
||||
&self.buffer.buffer()[24..24 + self.chunk_size() as usize]
|
||||
}
|
||||
|
||||
/// Set the chunk data in this message. This will also set the size field to the proper value.
|
||||
pub fn set_chunk_data(&mut self, data: &[u8]) -> Result<(), InsufficientChunkSpace> {
|
||||
let buf = self.buffer.buffer_mut();
|
||||
let available_space = buf.len() - 24;
|
||||
|
||||
if data.len() > available_space {
|
||||
return Err(InsufficientChunkSpace {
|
||||
available: available_space,
|
||||
needed: data.len(),
|
||||
});
|
||||
}
|
||||
|
||||
// Slicing based on data.len() is fine here as we just checked to make sure we can handle
|
||||
// this capacity.
|
||||
buf[24..24 + data.len()].copy_from_slice(data);
|
||||
|
||||
self.set_chunk_size(data.len() as u64);
|
||||
// Also set the extra space used by the buffer on the underlying packet.
|
||||
self.buffer.set_used_buffer_size(24 + data.len());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Convert the `MessageChunk` into a reply. This does nothing if it is already a reply.
|
||||
pub fn into_reply(mut self) -> Self {
|
||||
self.buffer.header_mut().flags_mut().set_ack();
|
||||
// We want to leave the length field in tact but don't want to copy the data in the reply.
|
||||
// This needs additional work on the underlying buffer.
|
||||
// TODO
|
||||
self
|
||||
}
|
||||
|
||||
/// Consumes this `MessageChunk`, returning the underlying [`MessagePacket`].
|
||||
pub fn into_inner(self) -> MessagePacket {
|
||||
self.buffer
|
||||
}
|
||||
}
|
||||
|
||||
/// An error indicating not enough space is availbe in a message to set the chunk data.
|
||||
#[derive(Debug)]
|
||||
pub struct InsufficientChunkSpace {
|
||||
/// Amount of space available in the chunk.
|
||||
pub available: usize,
|
||||
/// Amount of space needed to set the chunk data
|
||||
pub needed: usize,
|
||||
}
|
||||
|
||||
impl fmt::Display for InsufficientChunkSpace {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"Insufficient capacity available, needed {} bytes, have {} bytes",
|
||||
self.needed, self.available
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for InsufficientChunkSpace {}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::array;
|
||||
|
||||
use crate::{crypto::PacketBuffer, message::MessagePacket};
|
||||
|
||||
use super::MessageChunk;
|
||||
|
||||
#[test]
|
||||
fn chunk_flag_set() {
|
||||
let mc = MessageChunk::new(MessagePacket::new(PacketBuffer::new()));
|
||||
|
||||
let mp = mc.into_inner();
|
||||
assert!(mp.header().flags().chunk());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_chunk_idx() {
|
||||
let mut pb = PacketBuffer::new();
|
||||
pb.buffer_mut()[12..20].copy_from_slice(&[0, 0, 0, 0, 0, 0, 100, 73]);
|
||||
|
||||
let ms = MessageChunk::new(MessagePacket::new(pb));
|
||||
|
||||
assert_eq!(ms.chunk_idx(), 25_673);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn write_chunk_idx() {
|
||||
let mut ms = MessageChunk::new(MessagePacket::new(PacketBuffer::new()));
|
||||
|
||||
ms.set_chunk_idx(723);
|
||||
|
||||
// Since we don't work with packet buffer we don't have to account for the message packet
|
||||
// header.
|
||||
assert_eq!(&ms.buffer.buffer()[..8], &[0, 0, 0, 0, 0, 0, 2, 211]);
|
||||
assert_eq!(ms.chunk_idx(), 723);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_chunk_offset() {
|
||||
let mut pb = PacketBuffer::new();
|
||||
pb.buffer_mut()[20..28].copy_from_slice(&[0, 0, 0, 0, 0, 20, 40, 60]);
|
||||
|
||||
let ms = MessageChunk::new(MessagePacket::new(pb));
|
||||
|
||||
assert_eq!(ms.chunk_offset(), 1_321_020);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn write_chunk_offset() {
|
||||
let mut ms = MessageChunk::new(MessagePacket::new(PacketBuffer::new()));
|
||||
|
||||
ms.set_chunk_offset(1_000_000);
|
||||
|
||||
// Since we don't work with packet buffer we don't have to account for the message packet
|
||||
// header.
|
||||
assert_eq!(&ms.buffer.buffer()[8..16], &[0, 0, 0, 0, 0, 15, 66, 64]);
|
||||
assert_eq!(ms.chunk_offset(), 1_000_000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_chunk_size() {
|
||||
let mut pb = PacketBuffer::new();
|
||||
pb.buffer_mut()[28..36].copy_from_slice(&[0, 0, 0, 0, 0, 0, 3, 232]);
|
||||
|
||||
let ms = MessageChunk::new(MessagePacket::new(pb));
|
||||
|
||||
assert_eq!(ms.chunk_size(), 1_000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn write_chunk_size() {
|
||||
let mut ms = MessageChunk::new(MessagePacket::new(PacketBuffer::new()));
|
||||
|
||||
ms.set_chunk_size(1_300);
|
||||
|
||||
// Since we don't work with packet buffer we don't have to account for the message packet
|
||||
// header.
|
||||
assert_eq!(&ms.buffer.buffer()[16..24], &[0, 0, 0, 0, 0, 0, 5, 20]);
|
||||
assert_eq!(ms.chunk_size(), 1_300);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_chunk_data() {
|
||||
const CHUNK_DATA: &[u8] = &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
|
||||
let mut pb = PacketBuffer::new();
|
||||
// Set data len
|
||||
pb.buffer_mut()[28..36].copy_from_slice(&CHUNK_DATA.len().to_be_bytes());
|
||||
pb.buffer_mut()[36..36 + CHUNK_DATA.len()].copy_from_slice(CHUNK_DATA);
|
||||
|
||||
let ms = MessageChunk::new(MessagePacket::new(pb));
|
||||
|
||||
assert_eq!(ms.chunk_size(), 16);
|
||||
assert_eq!(ms.data(), CHUNK_DATA);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn write_chunk_data() {
|
||||
const CHUNK_DATA: &[u8] = &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
|
||||
let mut ms = MessageChunk::new(MessagePacket::new(PacketBuffer::new()));
|
||||
|
||||
let res = ms.set_chunk_data(CHUNK_DATA);
|
||||
assert!(res.is_ok());
|
||||
|
||||
// Since we don't work with packet buffer we don't have to account for the message packet
|
||||
// header.
|
||||
// Check and make sure size is properly set.
|
||||
assert_eq!(&ms.buffer.buffer()[16..24], &[0, 0, 0, 0, 0, 0, 0, 16]);
|
||||
assert_eq!(ms.chunk_size(), 16);
|
||||
assert_eq!(ms.data(), CHUNK_DATA);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn write_chunk_data_oversized() {
|
||||
let data: [u8; 1500] = array::from_fn(|_| 0xFF);
|
||||
let mut ms = MessageChunk::new(MessagePacket::new(PacketBuffer::new()));
|
||||
|
||||
let res = ms.set_chunk_data(&data);
|
||||
assert!(res.is_err());
|
||||
}
|
||||
}
|
||||
131
mycelium/src/message/done.rs
Normal file
131
mycelium/src/message/done.rs
Normal file
@@ -0,0 +1,131 @@
|
||||
use super::{MessageChecksum, MessagePacket, MESSAGE_CHECKSUM_LENGTH};
|
||||
|
||||
/// A message representing a "done" message.
|
||||
///
|
||||
/// The body of a done message has the following structure:
|
||||
/// - 8 bytes: chunks transmitted
|
||||
/// - 32 bytes: checksum of the transmitted data
|
||||
pub struct MessageDone {
|
||||
buffer: MessagePacket,
|
||||
}
|
||||
|
||||
impl MessageDone {
|
||||
/// Create a new `MessageDone` in the provided [`MessagePacket`].
|
||||
pub fn new(mut buffer: MessagePacket) -> Self {
|
||||
buffer.set_used_buffer_size(40);
|
||||
buffer.header_mut().flags_mut().set_done();
|
||||
Self { buffer }
|
||||
}
|
||||
|
||||
/// Return the amount of chunks in the message, as written in the body.
|
||||
pub fn chunk_count(&self) -> u64 {
|
||||
u64::from_be_bytes(
|
||||
self.buffer.buffer()[..8]
|
||||
.try_into()
|
||||
.expect("Buffer contains a size field of valid length; qed"),
|
||||
)
|
||||
}
|
||||
|
||||
/// Set the amount of chunks field of the message body.
|
||||
pub fn set_chunk_count(&mut self, chunk_count: u64) {
|
||||
self.buffer.buffer_mut()[..8].copy_from_slice(&chunk_count.to_be_bytes())
|
||||
}
|
||||
|
||||
/// Get the checksum of the message from the body.
|
||||
pub fn checksum(&self) -> MessageChecksum {
|
||||
MessageChecksum::from_bytes(
|
||||
self.buffer.buffer()[8..8 + MESSAGE_CHECKSUM_LENGTH]
|
||||
.try_into()
|
||||
.expect("Buffer contains enough data for a checksum; qed"),
|
||||
)
|
||||
}
|
||||
|
||||
/// Set the checksum of the message in the body.
|
||||
pub fn set_checksum(&mut self, checksum: MessageChecksum) {
|
||||
self.buffer.buffer_mut()[8..8 + MESSAGE_CHECKSUM_LENGTH]
|
||||
.copy_from_slice(checksum.as_bytes())
|
||||
}
|
||||
|
||||
/// Convert the `MessageDone` into a reply. This does nothing if it is already a reply.
|
||||
pub fn into_reply(mut self) -> Self {
|
||||
self.buffer.header_mut().flags_mut().set_ack();
|
||||
self
|
||||
}
|
||||
|
||||
/// Consumes this `MessageDone`, returning the underlying [`MessagePacket`].
|
||||
pub fn into_inner(self) -> MessagePacket {
|
||||
self.buffer
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{
|
||||
crypto::PacketBuffer,
|
||||
message::{MessageChecksum, MessagePacket},
|
||||
};
|
||||
|
||||
use super::MessageDone;
|
||||
|
||||
#[test]
|
||||
fn done_flag_set() {
|
||||
let md = MessageDone::new(MessagePacket::new(PacketBuffer::new()));
|
||||
|
||||
let mp = md.into_inner();
|
||||
assert!(mp.header().flags().done());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_chunk_count() {
|
||||
let mut pb = PacketBuffer::new();
|
||||
pb.buffer_mut()[12..20].copy_from_slice(&[0, 0, 0, 0, 0, 0, 73, 55]);
|
||||
|
||||
let ms = MessageDone::new(MessagePacket::new(pb));
|
||||
|
||||
assert_eq!(ms.chunk_count(), 18_743);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn write_chunk_count() {
|
||||
let mut ms = MessageDone::new(MessagePacket::new(PacketBuffer::new()));
|
||||
|
||||
ms.set_chunk_count(10_000);
|
||||
|
||||
// Since we don't work with packet buffer we don't have to account for the message packet
|
||||
// header.
|
||||
assert_eq!(&ms.buffer.buffer()[..8], &[0, 0, 0, 0, 0, 0, 39, 16]);
|
||||
assert_eq!(ms.chunk_count(), 10_000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_checksum() {
|
||||
const CHECKSUM: MessageChecksum = MessageChecksum::from_bytes([
|
||||
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D,
|
||||
0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1A, 0x1B,
|
||||
0x1C, 0x1D, 0x1E, 0x1F,
|
||||
]);
|
||||
let mut pb = PacketBuffer::new();
|
||||
pb.buffer_mut()[20..52].copy_from_slice(CHECKSUM.as_bytes());
|
||||
|
||||
let ms = MessageDone::new(MessagePacket::new(pb));
|
||||
|
||||
assert_eq!(ms.checksum(), CHECKSUM);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn write_checksum() {
|
||||
const CHECKSUM: MessageChecksum = MessageChecksum::from_bytes([
|
||||
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D,
|
||||
0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1A, 0x1B,
|
||||
0x1C, 0x1D, 0x1E, 0x1F,
|
||||
]);
|
||||
let mut ms = MessageDone::new(MessagePacket::new(PacketBuffer::new()));
|
||||
|
||||
ms.set_checksum(CHECKSUM);
|
||||
|
||||
// Since we don't work with packet buffer we don't have to account for the message packet
|
||||
// header.
|
||||
assert_eq!(&ms.buffer.buffer()[8..40], CHECKSUM.as_bytes());
|
||||
assert_eq!(ms.checksum(), CHECKSUM);
|
||||
}
|
||||
}
|
||||
101
mycelium/src/message/init.rs
Normal file
101
mycelium/src/message/init.rs
Normal file
@@ -0,0 +1,101 @@
|
||||
use super::MessagePacket;
|
||||
|
||||
/// A message representing an init message.
|
||||
///
|
||||
/// The body of an init message has the following structure:
|
||||
/// - 8 bytes size
|
||||
pub struct MessageInit {
|
||||
buffer: MessagePacket,
|
||||
}
|
||||
|
||||
impl MessageInit {
|
||||
/// Create a new `MessageInit` in the provided [`MessagePacket`].
|
||||
pub fn new(mut buffer: MessagePacket) -> Self {
|
||||
buffer.set_used_buffer_size(9);
|
||||
buffer.header_mut().flags_mut().set_init();
|
||||
Self { buffer }
|
||||
}
|
||||
|
||||
/// Return the length of the message, as written in the body.
|
||||
pub fn length(&self) -> u64 {
|
||||
u64::from_be_bytes(
|
||||
self.buffer.buffer()[..8]
|
||||
.try_into()
|
||||
.expect("Buffer contains a size field of valid length; qed"),
|
||||
)
|
||||
}
|
||||
|
||||
/// Return the topic of the message, as written in the body.
|
||||
pub fn topic(&self) -> &[u8] {
|
||||
let topic_len = self.buffer.buffer()[8] as usize;
|
||||
&self.buffer.buffer()[9..9 + topic_len]
|
||||
}
|
||||
|
||||
/// Set the length field of the message body.
|
||||
pub fn set_length(&mut self, length: u64) {
|
||||
self.buffer.buffer_mut()[..8].copy_from_slice(&length.to_be_bytes())
|
||||
}
|
||||
|
||||
/// Set the topic in the message body.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// This function panics if the topic is longer than 255 bytes.
|
||||
pub fn set_topic(&mut self, topic: &[u8]) {
|
||||
assert!(
|
||||
topic.len() <= u8::MAX as usize,
|
||||
"Topic can be 255 bytes long at most"
|
||||
);
|
||||
self.buffer.set_used_buffer_size(9 + topic.len());
|
||||
self.buffer.buffer_mut()[8] = topic.len() as u8;
|
||||
self.buffer.buffer_mut()[9..9 + topic.len()].copy_from_slice(topic);
|
||||
}
|
||||
|
||||
/// Convert the `MessageInit` into a reply. This does nothing if it is already a reply.
|
||||
pub fn into_reply(mut self) -> Self {
|
||||
self.buffer.header_mut().flags_mut().set_ack();
|
||||
self
|
||||
}
|
||||
|
||||
/// Consumes this `MessageInit`, returning the underlying [`MessagePacket`].
|
||||
pub fn into_inner(self) -> MessagePacket {
|
||||
self.buffer
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{crypto::PacketBuffer, message::MessagePacket};
|
||||
|
||||
use super::MessageInit;
|
||||
|
||||
#[test]
|
||||
fn init_flag_set() {
|
||||
let mi = MessageInit::new(MessagePacket::new(PacketBuffer::new()));
|
||||
|
||||
let mp = mi.into_inner();
|
||||
assert!(mp.header().flags().init());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_length() {
|
||||
let mut pb = PacketBuffer::new();
|
||||
pb.buffer_mut()[12..20].copy_from_slice(&[0, 0, 0, 0, 2, 3, 4, 5]);
|
||||
|
||||
let ms = MessageInit::new(MessagePacket::new(pb));
|
||||
|
||||
assert_eq!(ms.length(), 33_752_069);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn write_length() {
|
||||
let mut ms = MessageInit::new(MessagePacket::new(PacketBuffer::new()));
|
||||
|
||||
ms.set_length(3_432_634_632);
|
||||
|
||||
// Since we don't work with packet buffer we don't have to account for the message packet
|
||||
// header.
|
||||
assert_eq!(&ms.buffer.buffer()[..8], &[0, 0, 0, 0, 204, 153, 217, 8]);
|
||||
assert_eq!(ms.length(), 3_432_634_632);
|
||||
}
|
||||
}
|
||||
230
mycelium/src/message/topic.rs
Normal file
230
mycelium/src/message/topic.rs
Normal file
@@ -0,0 +1,230 @@
|
||||
use crate::subnet::Subnet;
|
||||
use core::fmt;
|
||||
use serde::{
|
||||
de::{Deserialize, Deserializer, MapAccess, Visitor},
|
||||
Deserialize as DeserializeMacro,
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// Configuration for a topic whitelist, including allowed subnets and optional forward socket
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct TopicWhitelistConfig {
|
||||
/// Subnets that are allowed to send messages to this topic
|
||||
subnets: Vec<Subnet>,
|
||||
/// Optional Unix domain socket path to forward messages to
|
||||
forward_socket: Option<PathBuf>,
|
||||
}
|
||||
|
||||
impl TopicWhitelistConfig {
|
||||
/// Create a new empty whitelist config
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Get the list of whitelisted subnets
|
||||
pub fn subnets(&self) -> &Vec<Subnet> {
|
||||
&self.subnets
|
||||
}
|
||||
|
||||
/// Get the forward socket path, if any
|
||||
pub fn forward_socket(&self) -> Option<&PathBuf> {
|
||||
self.forward_socket.as_ref()
|
||||
}
|
||||
|
||||
/// Set the forward socket path
|
||||
pub fn set_forward_socket(&mut self, path: Option<PathBuf>) {
|
||||
self.forward_socket = path;
|
||||
}
|
||||
|
||||
/// Add a subnet to the whitelist
|
||||
pub fn add_subnet(&mut self, subnet: Subnet) {
|
||||
self.subnets.push(subnet);
|
||||
}
|
||||
|
||||
/// Remove a subnet from the whitelist
|
||||
pub fn remove_subnet(&mut self, subnet: &Subnet) {
|
||||
self.subnets.retain(|s| s != subnet);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct TopicConfig {
|
||||
/// The default action to to take if no acl is defined for a topic.
|
||||
default: MessageAction,
|
||||
/// Explicitly configured whitelists for topics. Ip's which aren't part of the whitelist will
|
||||
/// not be allowed to send messages to that topic. If a topic is not in this map, the default
|
||||
/// action will be used.
|
||||
whitelist: HashMap<Vec<u8>, TopicWhitelistConfig>,
|
||||
}
|
||||
|
||||
impl TopicConfig {
|
||||
/// Get the [`default action`](MessageAction) if the topic is not configured.
|
||||
pub fn default(&self) -> MessageAction {
|
||||
self.default
|
||||
}
|
||||
|
||||
/// Set the default [`action`](MessageAction) which does not have a whitelist configured.
|
||||
pub fn set_default(&mut self, default: MessageAction) {
|
||||
self.default = default;
|
||||
}
|
||||
|
||||
/// Get the fully configured whitelist
|
||||
pub fn whitelist(&self) -> &HashMap<Vec<u8>, TopicWhitelistConfig> {
|
||||
&self.whitelist
|
||||
}
|
||||
|
||||
/// Insert a new topic in the whitelist, without any configured allowed sources.
|
||||
pub fn add_topic_whitelist(&mut self, topic: Vec<u8>) {
|
||||
self.whitelist.entry(topic).or_default();
|
||||
}
|
||||
|
||||
/// Set the forward socket for a topic. Does nothing if the topic doesn't exist.
|
||||
pub fn set_topic_forward_socket(&mut self, topic: Vec<u8>, socket_path: Option<PathBuf>) {
|
||||
self.whitelist
|
||||
.entry(topic)
|
||||
.and_modify(|c| c.set_forward_socket(socket_path));
|
||||
}
|
||||
|
||||
/// Get the forward socket for a topic, if any.
|
||||
pub fn get_topic_forward_socket(&self, topic: &Vec<u8>) -> Option<&PathBuf> {
|
||||
self.whitelist
|
||||
.get(topic)
|
||||
.and_then(|config| config.forward_socket())
|
||||
}
|
||||
|
||||
/// Remove a topic from the whitelist. Future messages will follow the default action.
|
||||
pub fn remove_topic_whitelist(&mut self, topic: &Vec<u8>) {
|
||||
self.whitelist.remove(topic);
|
||||
}
|
||||
|
||||
/// Adds a new whitelisted source for a topic. This creates the topic if it does not exist yet.
|
||||
pub fn add_topic_whitelist_src(&mut self, topic: Vec<u8>, src: Subnet) {
|
||||
self.whitelist.entry(topic).or_default().add_subnet(src);
|
||||
}
|
||||
|
||||
/// Removes a whitelisted source for a topic.
|
||||
///
|
||||
/// If the last source is removed for a topic, the entry remains, and must be cleared by calling
|
||||
/// [`Self::remove_topic_whitelist`] to fall back to the default action. Note that an empty
|
||||
/// whitelist effectively blocks all messages for a topic.
|
||||
///
|
||||
/// This does nothing if the topic does not exist.
|
||||
pub fn remove_topic_whitelist_src(&mut self, topic: &Vec<u8>, src: Subnet) {
|
||||
if let Some(whitelist_config) = self.whitelist.get_mut(topic) {
|
||||
whitelist_config.remove_subnet(&src);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Clone, Copy, DeserializeMacro)]
|
||||
pub enum MessageAction {
|
||||
/// Accept the message
|
||||
#[default]
|
||||
Accept,
|
||||
/// Reject the message
|
||||
Reject,
|
||||
}
|
||||
|
||||
// Helper function to parse a subnet from a string
|
||||
fn parse_subnet_str<E>(s: &str) -> Result<Subnet, E>
|
||||
where
|
||||
E: serde::de::Error,
|
||||
{
|
||||
// Try to parse as a subnet (with prefix)
|
||||
if let Ok(ipnet) = s.parse::<ipnet::IpNet>() {
|
||||
return Subnet::new(ipnet.addr(), ipnet.prefix_len())
|
||||
.map_err(|e| serde::de::Error::custom(format!("Invalid subnet prefix length: {e}")));
|
||||
}
|
||||
|
||||
// Try to parse as an IP address (convert to /32 or /128 subnet)
|
||||
if let Ok(ip) = s.parse::<std::net::IpAddr>() {
|
||||
let prefix_len = match ip {
|
||||
std::net::IpAddr::V4(_) => 32,
|
||||
std::net::IpAddr::V6(_) => 128,
|
||||
};
|
||||
return Subnet::new(ip, prefix_len)
|
||||
.map_err(|e| serde::de::Error::custom(format!("Invalid subnet prefix length: {e}")));
|
||||
}
|
||||
|
||||
Err(serde::de::Error::custom(format!(
|
||||
"Invalid subnet or IP address: {s}",
|
||||
)))
|
||||
}
|
||||
|
||||
// Define a struct for deserializing the whitelist config
|
||||
#[derive(DeserializeMacro)]
|
||||
struct WhitelistConfigData {
|
||||
#[serde(default)]
|
||||
subnets: Vec<String>,
|
||||
#[serde(default)]
|
||||
forward_socket: Option<String>,
|
||||
}
|
||||
|
||||
// Add this implementation right after the TopicConfig struct definition
|
||||
impl<'de> Deserialize<'de> for TopicConfig {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
struct TopicConfigVisitor;
|
||||
|
||||
impl<'de> Visitor<'de> for TopicConfigVisitor {
|
||||
type Value = TopicConfig;
|
||||
|
||||
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
|
||||
formatter.write_str("a topic configuration")
|
||||
}
|
||||
|
||||
fn visit_map<V>(self, mut map: V) -> Result<TopicConfig, V::Error>
|
||||
where
|
||||
V: MapAccess<'de>,
|
||||
{
|
||||
let mut default = MessageAction::default();
|
||||
let mut whitelist = HashMap::new();
|
||||
|
||||
while let Some(key) = map.next_key::<String>()? {
|
||||
if key == "default" {
|
||||
default = map.next_value()?;
|
||||
} else {
|
||||
// Try to parse as a WhitelistConfigData first
|
||||
if let Ok(config_data) = map.next_value::<WhitelistConfigData>() {
|
||||
let mut whitelist_config = TopicWhitelistConfig::default();
|
||||
|
||||
// Process subnets
|
||||
for subnet_str in config_data.subnets {
|
||||
let subnet = parse_subnet_str(&subnet_str)?;
|
||||
whitelist_config.add_subnet(subnet);
|
||||
}
|
||||
|
||||
// Process forward_socket
|
||||
if let Some(socket_path) = config_data.forward_socket {
|
||||
whitelist_config
|
||||
.set_forward_socket(Some(PathBuf::from(socket_path)));
|
||||
}
|
||||
|
||||
// Convert string key to Vec<u8>
|
||||
whitelist.insert(key.into_bytes(), whitelist_config);
|
||||
} else {
|
||||
// Fallback to old format: just a list of subnets
|
||||
let subnet_strs = map.next_value::<Vec<String>>()?;
|
||||
let mut whitelist_config = TopicWhitelistConfig::default();
|
||||
|
||||
for subnet_str in subnet_strs {
|
||||
let subnet = parse_subnet_str(&subnet_str)?;
|
||||
whitelist_config.add_subnet(subnet);
|
||||
}
|
||||
|
||||
// Convert string key to Vec<u8>
|
||||
whitelist.insert(key.into_bytes(), whitelist_config);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(TopicConfig { default, whitelist })
|
||||
}
|
||||
}
|
||||
|
||||
deserializer.deserialize_map(TopicConfigVisitor)
|
||||
}
|
||||
}
|
||||
144
mycelium/src/metric.rs
Normal file
144
mycelium/src/metric.rs
Normal file
@@ -0,0 +1,144 @@
|
||||
//! Dedicated logic for
|
||||
//! [metrics](https://datatracker.ietf.org/doc/html/rfc8966#metric-computation).
|
||||
|
||||
use core::fmt;
|
||||
use std::ops::{Add, Sub};
|
||||
|
||||
/// Value of the infinite metric.
|
||||
const METRIC_INFINITE: u16 = 0xFFFF;
|
||||
|
||||
/// A `Metric` is used to indicate the cost associated with a route. A lower Metric means a route
|
||||
/// is more favorable.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Ord, PartialOrd)]
|
||||
pub struct Metric(u16);
|
||||
|
||||
impl Metric {
|
||||
/// Create a new `Metric` with the given value.
|
||||
pub const fn new(value: u16) -> Self {
|
||||
Metric(value)
|
||||
}
|
||||
|
||||
/// Creates a new infinite `Metric`.
|
||||
pub const fn infinite() -> Self {
|
||||
Metric(METRIC_INFINITE)
|
||||
}
|
||||
|
||||
/// Checks if this metric indicates a retracted route.
|
||||
pub const fn is_infinite(&self) -> bool {
|
||||
self.0 == METRIC_INFINITE
|
||||
}
|
||||
|
||||
/// Checks if this metric represents a directly connected route.
|
||||
pub const fn is_direct(&self) -> bool {
|
||||
self.0 == 0
|
||||
}
|
||||
|
||||
/// Computes the absolute value of the difference between this and another `Metric`.
|
||||
pub fn delta(&self, rhs: &Self) -> Metric {
|
||||
Metric(if self > rhs {
|
||||
self.0 - rhs.0
|
||||
} else {
|
||||
rhs.0 - self.0
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for Metric {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
if self.is_infinite() {
|
||||
f.pad("Infinite")
|
||||
} else {
|
||||
f.write_fmt(format_args!("{}", self.0))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<u16> for Metric {
|
||||
fn from(value: u16) -> Self {
|
||||
Metric(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Metric> for u16 {
|
||||
fn from(value: Metric) -> Self {
|
||||
value.0
|
||||
}
|
||||
}
|
||||
|
||||
impl Add for Metric {
|
||||
type Output = Self;
|
||||
|
||||
fn add(self, rhs: Metric) -> Self::Output {
|
||||
if self.is_infinite() || rhs.is_infinite() {
|
||||
return Metric::infinite();
|
||||
}
|
||||
Metric(
|
||||
self.0
|
||||
.checked_add(rhs.0)
|
||||
.map(|r| if r == u16::MAX { r - 1 } else { r })
|
||||
.unwrap_or(u16::MAX - 1),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl Add<&Metric> for &Metric {
|
||||
type Output = Metric;
|
||||
|
||||
fn add(self, rhs: &Metric) -> Self::Output {
|
||||
if self.is_infinite() || rhs.is_infinite() {
|
||||
return Metric::infinite();
|
||||
}
|
||||
Metric(
|
||||
self.0
|
||||
.checked_add(rhs.0)
|
||||
.map(|r| if r == u16::MAX { r - 1 } else { r })
|
||||
.unwrap_or(u16::MAX - 1),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl Add<&Metric> for Metric {
|
||||
type Output = Self;
|
||||
|
||||
fn add(self, rhs: &Metric) -> Self::Output {
|
||||
if self.is_infinite() || rhs.is_infinite() {
|
||||
return Metric::infinite();
|
||||
}
|
||||
Metric(
|
||||
self.0
|
||||
.checked_add(rhs.0)
|
||||
.map(|r| if r == u16::MAX { r - 1 } else { r })
|
||||
.unwrap_or(u16::MAX - 1),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl Add<Metric> for &Metric {
|
||||
type Output = Metric;
|
||||
|
||||
fn add(self, rhs: Metric) -> Self::Output {
|
||||
if self.is_infinite() || rhs.is_infinite() {
|
||||
return Metric::infinite();
|
||||
}
|
||||
Metric(
|
||||
self.0
|
||||
.checked_add(rhs.0)
|
||||
.map(|r| if r == u16::MAX { r - 1 } else { r })
|
||||
.unwrap_or(u16::MAX - 1),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl Sub<Metric> for Metric {
|
||||
type Output = Metric;
|
||||
|
||||
fn sub(self, rhs: Metric) -> Self::Output {
|
||||
if rhs.is_infinite() {
|
||||
panic!("Can't subtract an infinite metric");
|
||||
}
|
||||
if self.is_infinite() {
|
||||
return Metric::infinite();
|
||||
}
|
||||
Metric(self.0.saturating_sub(rhs.0))
|
||||
}
|
||||
}
|
||||
195
mycelium/src/metrics.rs
Normal file
195
mycelium/src/metrics.rs
Normal file
@@ -0,0 +1,195 @@
|
||||
//! This module is used for collection of runtime metrics of a `mycelium` system. The main item of
|
||||
//! interest is the [`Metrics`] trait. Users can provide their own implementation of this, or use
|
||||
//! the default provided implementation to disable gathering metrics.
|
||||
|
||||
use crate::peer_manager::PeerType;
|
||||
|
||||
/// The collection of all metrics exported by a [`mycelium node`](crate::Node). It is up to the
|
||||
/// user to provide an implementation which implements the methods for metrics they are interested
|
||||
/// in. All methods have a default implementation, so if the user is not interested in any metrics,
|
||||
/// a NOOP handler can be implemented as follows:
|
||||
///
|
||||
/// ```rust
|
||||
/// use mycelium::metrics::Metrics;
|
||||
///
|
||||
/// #[derive(Clone)]
|
||||
/// struct NoMetrics;
|
||||
/// impl Metrics for NoMetrics {}
|
||||
/// ```
|
||||
pub trait Metrics {
|
||||
/// The [`Router`](crate::router::Router) received a new Hello TLV from a peer.
|
||||
#[inline]
|
||||
fn router_process_hello(&self) {}
|
||||
|
||||
/// The [`Router`](crate::router::Router) received a new IHU TLV from a peer.
|
||||
#[inline]
|
||||
fn router_process_ihu(&self) {}
|
||||
|
||||
/// The [`Router`](crate::router::Router) received a new Seqno request TLV from a peer.
|
||||
#[inline]
|
||||
fn router_process_seqno_request(&self) {}
|
||||
|
||||
/// The [`Router`](crate::router::Router) received a new Route request TLV from a peer.
|
||||
/// Additionally, it is recorded if this is a wildcard request (route table dump request)
|
||||
/// or a request for a specific subnet.
|
||||
#[inline]
|
||||
fn router_process_route_request(&self, _wildcard: bool) {}
|
||||
|
||||
/// The [`Router`](crate::router::Router) received a new Update TLV from a peer.
|
||||
#[inline]
|
||||
fn router_process_update(&self) {}
|
||||
|
||||
/// The [`Router`](crate::router::Router) tried to send an update to a peer, but before sending
|
||||
/// it we found out the peer is actually already dead.
|
||||
///
|
||||
/// This can happen, since a peer is a remote entity we have no control over, and it can be
|
||||
/// removed at any time for any reason. However, in normal operation, the amount of times this
|
||||
/// happens should be fairly small compared to the amount of updates we send/receive.
|
||||
#[inline]
|
||||
fn router_update_dead_peer(&self) {}
|
||||
|
||||
/// The amount of TLV's received from peers, to be processed by the
|
||||
/// [`Router`](crate::router::Router).
|
||||
#[inline]
|
||||
fn router_received_tlv(&self) {}
|
||||
|
||||
/// The [`Router`](crate::router::Router) dropped a received TLV before processing it, as the
|
||||
/// peer who sent it has already died in the meantime.
|
||||
#[inline]
|
||||
fn router_tlv_source_died(&self) {}
|
||||
|
||||
/// The [`Router`](crate::router::Router) dropped a received TLV before processing it, because
|
||||
/// it coulnd't keep up
|
||||
#[inline]
|
||||
fn router_tlv_discarded(&self) {}
|
||||
|
||||
/// A [`Peer`](crate::peer::Peer) was added to the [`Router`](crate::router::Router).
|
||||
#[inline]
|
||||
fn router_peer_added(&self) {}
|
||||
|
||||
/// A [`Peer`](crate::peer::Peer) was removed from the [`Router`](crate::router::Router).
|
||||
#[inline]
|
||||
fn router_peer_removed(&self) {}
|
||||
|
||||
/// A [`Peer`](crate::peer::Peer) informed the [`Router`](crate::router::Router) it died, or
|
||||
/// the router otherwise noticed the Peer is dead.
|
||||
#[inline]
|
||||
fn router_peer_died(&self) {}
|
||||
|
||||
/// The [`Router`](crate::router::Router) ran a route selection procedure.
|
||||
#[inline]
|
||||
fn router_route_selection_ran(&self) {}
|
||||
|
||||
/// A [`SourceKey`](crate::source_table::SourceKey) expired and got cleaned up by the [`Router`](crate::router::Router).
|
||||
#[inline]
|
||||
fn router_source_key_expired(&self) {}
|
||||
|
||||
/// A [`RouteKey`](crate::routing_table::RouteKey) expired, and the router either set the
|
||||
/// [`Metric`](crate::metric::Metric) of the route to infinity, or cleaned up the route entry
|
||||
/// altogether.
|
||||
#[inline]
|
||||
fn router_route_key_expired(&self, _removed: bool) {}
|
||||
|
||||
/// A route which expired was actually the selected route for the
|
||||
/// [`Subnet`](crate::subnet::Subnet). Note that [`Self::router_route_key_expired`] will
|
||||
/// also have been called.
|
||||
#[inline]
|
||||
fn router_selected_route_expired(&self) {}
|
||||
|
||||
/// The [`Router`](crate::router::Router) sends a "triggered" update to it's peers.
|
||||
#[inline]
|
||||
fn router_triggered_update(&self) {}
|
||||
|
||||
/// The [`Router`](crate::router::Router) extracted a packet for the local subnet.
|
||||
#[inline]
|
||||
fn router_route_packet_local(&self) {}
|
||||
|
||||
/// The [`Router`](crate::router::Router) forwarded a packet to a peer.
|
||||
#[inline]
|
||||
fn router_route_packet_forward(&self) {}
|
||||
|
||||
/// The [`Router`](crate::router::Router) dropped a packet it was routing because it's TTL
|
||||
/// reached 0.
|
||||
#[inline]
|
||||
fn router_route_packet_ttl_expired(&self) {}
|
||||
|
||||
/// The [`Router`](crate::router::Router) dropped a packet it was routing because there was no
|
||||
/// route for the destination IP.
|
||||
#[inline]
|
||||
fn router_route_packet_no_route(&self) {}
|
||||
|
||||
/// The [`Router`](crate::router::Router) replied to a seqno request with a local route, which
|
||||
/// is more recent (bigger seqno) than the request.
|
||||
#[inline]
|
||||
fn router_seqno_request_reply_local(&self) {}
|
||||
|
||||
/// The [`Router`](crate::router::Router) replied to a seqno request by bumping its own seqno
|
||||
/// and advertising the local route.
|
||||
#[inline]
|
||||
fn router_seqno_request_bump_seqno(&self) {}
|
||||
|
||||
/// The [`Router`](crate::router::Router) dropped a seqno request because the TTL reached 0.
|
||||
#[inline]
|
||||
fn router_seqno_request_dropped_ttl(&self) {}
|
||||
|
||||
/// The [`Router`](crate::router::Router) forwarded a seqno request to a feasible route.
|
||||
#[inline]
|
||||
fn router_seqno_request_forward_feasible(&self) {}
|
||||
|
||||
/// The [`Router`](crate::router::Router) forwarded a seqno request to a (potentially)
|
||||
/// unfeasible route.
|
||||
#[inline]
|
||||
fn router_seqno_request_forward_unfeasible(&self) {}
|
||||
|
||||
/// The [`Router`](crate::router::Router) dropped a seqno request becase none of the other
|
||||
/// handling methods applied.
|
||||
#[inline]
|
||||
fn router_seqno_request_unhandled(&self) {}
|
||||
|
||||
/// The [`time`](std::time::Duration) used by the [`Router`](crate::router::Router) to handle a
|
||||
/// control packet.
|
||||
#[inline]
|
||||
fn router_time_spent_handling_tlv(&self, _duration: std::time::Duration, _tlv_type: &str) {}
|
||||
|
||||
/// The [`time`](std::time::Duration) used by the [`Router`](crate::router::Router) to
|
||||
/// periodically propagate selected routes to peers.
|
||||
#[inline]
|
||||
fn router_time_spent_periodic_propagating_selected_routes(
|
||||
&self,
|
||||
_duration: std::time::Duration,
|
||||
) {
|
||||
}
|
||||
|
||||
/// An update was processed and accepted by the router, but did not run route selection.
|
||||
#[inline]
|
||||
fn router_update_skipped_route_selection(&self) {}
|
||||
|
||||
/// An update was denied by a configured filter.
|
||||
#[inline]
|
||||
fn router_update_denied_by_filter(&self) {}
|
||||
|
||||
/// An update was accepted by the router filters, but was otherwise unfeasible or a retraction,
|
||||
/// for an unknown subnet.
|
||||
#[inline]
|
||||
fn router_update_not_interested(&self) {}
|
||||
|
||||
/// A new [`Peer`](crate::peer::Peer) was added to the
|
||||
/// [`PeerManager`](crate::peer_manager::PeerManager) while it is running.
|
||||
#[inline]
|
||||
fn peer_manager_peer_added(&self, _pt: PeerType) {}
|
||||
|
||||
/// Sets the amount of [`Peers`](crate::peer::Peer) known by the
|
||||
/// [`PeerManager`](crate::peer_manager::PeerManager).
|
||||
#[inline]
|
||||
fn peer_manager_known_peers(&self, _amount: usize) {}
|
||||
|
||||
/// The [`PeerManager`](crate::peer_manager::PeerManager) started an attempt to connect to a
|
||||
/// remote endpoint.
|
||||
#[inline]
|
||||
fn peer_manager_connection_attempted(&self) {}
|
||||
|
||||
/// The [`PeerManager`](crate::peer_manager::PeerManager) finished an attempt to connect to a
|
||||
/// remote endpoint. The connection could have failed.
|
||||
#[inline]
|
||||
fn peer_manager_connection_finished(&self) {}
|
||||
}
|
||||
134
mycelium/src/packet.rs
Normal file
134
mycelium/src/packet.rs
Normal file
@@ -0,0 +1,134 @@
|
||||
use bytes::{Buf, BufMut, BytesMut};
|
||||
pub use control::ControlPacket;
|
||||
pub use data::DataPacket;
|
||||
use tokio_util::codec::{Decoder, Encoder};
|
||||
|
||||
mod control;
|
||||
mod data;
|
||||
|
||||
/// Current version of the protocol being used.
|
||||
const PROTOCOL_VERSION: u8 = 1;
|
||||
|
||||
/// The size of a `Packet` header on the wire, in bytes.
|
||||
const PACKET_HEADER_SIZE: usize = 4;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Packet {
|
||||
DataPacket(DataPacket),
|
||||
ControlPacket(ControlPacket),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
#[repr(u8)]
|
||||
pub enum PacketType {
|
||||
DataPacket = 0,
|
||||
ControlPacket = 1,
|
||||
}
|
||||
|
||||
pub struct Codec {
|
||||
packet_type: Option<PacketType>,
|
||||
data_packet_codec: data::Codec,
|
||||
control_packet_codec: control::Codec,
|
||||
}
|
||||
|
||||
impl Codec {
|
||||
pub fn new() -> Self {
|
||||
Codec {
|
||||
packet_type: None,
|
||||
data_packet_codec: data::Codec::new(),
|
||||
control_packet_codec: control::Codec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Decoder for Codec {
|
||||
type Item = Packet;
|
||||
type Error = std::io::Error;
|
||||
|
||||
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
|
||||
// Determine the packet_type
|
||||
let packet_type = if let Some(packet_type) = self.packet_type {
|
||||
packet_type
|
||||
} else {
|
||||
// Check we can read the header
|
||||
if src.remaining() <= PACKET_HEADER_SIZE {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let mut header = [0; PACKET_HEADER_SIZE];
|
||||
header.copy_from_slice(&src[..PACKET_HEADER_SIZE]);
|
||||
src.advance(PACKET_HEADER_SIZE);
|
||||
|
||||
// For now it's a hard error to not follow the 1 defined protocol version
|
||||
if header[0] != PROTOCOL_VERSION {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
"Unknown protocol version",
|
||||
));
|
||||
};
|
||||
|
||||
let packet_type_byte = header[1];
|
||||
let packet_type = match packet_type_byte {
|
||||
0 => PacketType::DataPacket,
|
||||
1 => PacketType::ControlPacket,
|
||||
_ => {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
"Invalid packet type",
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
self.packet_type = Some(packet_type);
|
||||
|
||||
packet_type
|
||||
};
|
||||
|
||||
// Decode packet based on determined packet_type
|
||||
match packet_type {
|
||||
PacketType::DataPacket => {
|
||||
match self.data_packet_codec.decode(src) {
|
||||
Ok(Some(p)) => {
|
||||
self.packet_type = None; // Reset state
|
||||
Ok(Some(Packet::DataPacket(p)))
|
||||
}
|
||||
Ok(None) => Ok(None),
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
PacketType::ControlPacket => {
|
||||
match self.control_packet_codec.decode(src) {
|
||||
Ok(Some(p)) => {
|
||||
self.packet_type = None; // Reset state
|
||||
Ok(Some(Packet::ControlPacket(p)))
|
||||
}
|
||||
Ok(None) => Ok(None),
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Encoder<Packet> for Codec {
|
||||
type Error = std::io::Error;
|
||||
|
||||
fn encode(&mut self, item: Packet, dst: &mut BytesMut) -> Result<(), Self::Error> {
|
||||
match item {
|
||||
Packet::DataPacket(datapacket) => {
|
||||
dst.put_slice(&[PROTOCOL_VERSION, 0, 0, 0]);
|
||||
self.data_packet_codec.encode(datapacket, dst)
|
||||
}
|
||||
Packet::ControlPacket(controlpacket) => {
|
||||
dst.put_slice(&[PROTOCOL_VERSION, 1, 0, 0]);
|
||||
self.control_packet_codec.encode(controlpacket, dst)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Codec {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
64
mycelium/src/packet/control.rs
Normal file
64
mycelium/src/packet/control.rs
Normal file
@@ -0,0 +1,64 @@
|
||||
use std::{io, net::IpAddr, time::Duration};
|
||||
|
||||
use bytes::BytesMut;
|
||||
use tokio_util::codec::{Decoder, Encoder};
|
||||
|
||||
use crate::{
|
||||
babel, metric::Metric, peer::Peer, router_id::RouterId, sequence_number::SeqNo, subnet::Subnet,
|
||||
};
|
||||
|
||||
pub type ControlPacket = babel::Tlv;
|
||||
|
||||
pub struct Codec {
|
||||
// TODO: wrapper to make it easier to deserialize
|
||||
codec: babel::Codec,
|
||||
}
|
||||
|
||||
impl ControlPacket {
|
||||
pub fn new_hello(dest_peer: &Peer, interval: Duration) -> Self {
|
||||
let tlv: babel::Tlv =
|
||||
babel::Hello::new_unicast(dest_peer.hello_seqno(), (interval.as_millis() / 10) as u16)
|
||||
.into();
|
||||
dest_peer.increment_hello_seqno();
|
||||
tlv
|
||||
}
|
||||
|
||||
pub fn new_ihu(rx_cost: Metric, interval: Duration, dest_address: Option<IpAddr>) -> Self {
|
||||
babel::Ihu::new(rx_cost, (interval.as_millis() / 10) as u16, dest_address).into()
|
||||
}
|
||||
|
||||
pub fn new_update(
|
||||
interval: Duration,
|
||||
seqno: SeqNo,
|
||||
metric: Metric,
|
||||
subnet: Subnet,
|
||||
router_id: RouterId,
|
||||
) -> Self {
|
||||
babel::Update::new(interval, seqno, metric, subnet, router_id).into()
|
||||
}
|
||||
}
|
||||
|
||||
impl Codec {
|
||||
pub fn new() -> Self {
|
||||
Codec {
|
||||
codec: babel::Codec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Decoder for Codec {
|
||||
type Item = ControlPacket;
|
||||
type Error = std::io::Error;
|
||||
|
||||
fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
|
||||
self.codec.decode(buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl Encoder<ControlPacket> for Codec {
|
||||
type Error = io::Error;
|
||||
|
||||
fn encode(&mut self, message: ControlPacket, buf: &mut BytesMut) -> Result<(), Self::Error> {
|
||||
self.codec.encode(message, buf)
|
||||
}
|
||||
}
|
||||
154
mycelium/src/packet/data.rs
Normal file
154
mycelium/src/packet/data.rs
Normal file
@@ -0,0 +1,154 @@
|
||||
use std::net::Ipv6Addr;
|
||||
|
||||
use bytes::{Buf, BufMut, BytesMut};
|
||||
use tokio_util::codec::{Decoder, Encoder};
|
||||
|
||||
/// Size of the header start for a data packet (before the IP addresses).
|
||||
const DATA_PACKET_HEADER_SIZE: usize = 4;
|
||||
|
||||
/// Mask to extract data length from
|
||||
const DATA_PACKET_LEN_MASK: u32 = (1 << 16) - 1;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DataPacket {
|
||||
pub raw_data: Vec<u8>, // encrypted data itself, then append the nonce
|
||||
/// Max amount of hops for the packet.
|
||||
pub hop_limit: u8,
|
||||
pub src_ip: Ipv6Addr,
|
||||
pub dst_ip: Ipv6Addr,
|
||||
}
|
||||
|
||||
pub struct Codec {
|
||||
header_vals: Option<HeaderValues>,
|
||||
src_ip: Option<Ipv6Addr>,
|
||||
dest_ip: Option<Ipv6Addr>,
|
||||
}
|
||||
|
||||
/// Data from the DataPacket header.
|
||||
#[derive(Clone, Copy)]
|
||||
struct HeaderValues {
|
||||
len: u16,
|
||||
hop_limit: u8,
|
||||
}
|
||||
|
||||
impl Codec {
|
||||
pub fn new() -> Self {
|
||||
Codec {
|
||||
header_vals: None,
|
||||
src_ip: None,
|
||||
dest_ip: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Decoder for Codec {
|
||||
type Item = DataPacket;
|
||||
type Error = std::io::Error;
|
||||
|
||||
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
|
||||
// Determine the length of the data
|
||||
let HeaderValues { len, hop_limit } = if let Some(header_vals) = self.header_vals {
|
||||
header_vals
|
||||
} else {
|
||||
// Check we have enough data to decode
|
||||
if src.len() < DATA_PACKET_HEADER_SIZE {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let raw_header = src.get_u32();
|
||||
// Hop limit is the last 8 bits.
|
||||
let hop_limit = (raw_header & 0xFF) as u8;
|
||||
let data_len = ((raw_header >> 8) & DATA_PACKET_LEN_MASK) as u16;
|
||||
let header_vals = HeaderValues {
|
||||
len: data_len,
|
||||
hop_limit,
|
||||
};
|
||||
|
||||
self.header_vals = Some(header_vals);
|
||||
|
||||
header_vals
|
||||
};
|
||||
|
||||
let data_len = len as usize;
|
||||
|
||||
// Determine the source IP
|
||||
let src_ip = if let Some(src_ip) = self.src_ip {
|
||||
src_ip
|
||||
} else {
|
||||
if src.len() < 16 {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// Decode octets
|
||||
let mut ip_bytes = [0u8; 16];
|
||||
ip_bytes.copy_from_slice(&src[..16]);
|
||||
let src_ip = Ipv6Addr::from(ip_bytes);
|
||||
src.advance(16);
|
||||
|
||||
self.src_ip = Some(src_ip);
|
||||
src_ip
|
||||
};
|
||||
|
||||
// Determine the destination IP
|
||||
let dest_ip = if let Some(dest_ip) = self.dest_ip {
|
||||
dest_ip
|
||||
} else {
|
||||
if src.len() < 16 {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// Decode octets
|
||||
let mut ip_bytes = [0u8; 16];
|
||||
ip_bytes.copy_from_slice(&src[..16]);
|
||||
let dest_ip = Ipv6Addr::from(ip_bytes);
|
||||
src.advance(16);
|
||||
|
||||
self.dest_ip = Some(dest_ip);
|
||||
dest_ip
|
||||
};
|
||||
|
||||
// Check we have enough data to decode
|
||||
if src.len() < data_len {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// Decode octets
|
||||
let mut data = vec![0u8; data_len];
|
||||
data.copy_from_slice(&src[..data_len]);
|
||||
src.advance(data_len);
|
||||
|
||||
// Reset state
|
||||
self.header_vals = None;
|
||||
self.dest_ip = None;
|
||||
self.src_ip = None;
|
||||
|
||||
Ok(Some(DataPacket {
|
||||
raw_data: data,
|
||||
hop_limit,
|
||||
dst_ip: dest_ip,
|
||||
src_ip,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
impl Encoder<DataPacket> for Codec {
|
||||
type Error = std::io::Error;
|
||||
|
||||
fn encode(&mut self, item: DataPacket, dst: &mut BytesMut) -> Result<(), Self::Error> {
|
||||
dst.reserve(item.raw_data.len() + DATA_PACKET_HEADER_SIZE + 16 + 16);
|
||||
let mut raw_header = 0;
|
||||
// Add length of the data
|
||||
raw_header |= (item.raw_data.len() as u32) << 8;
|
||||
// And hop limit
|
||||
raw_header |= item.hop_limit as u32;
|
||||
dst.put_u32(raw_header);
|
||||
// Write the source IP
|
||||
dst.put_slice(&item.src_ip.octets());
|
||||
// Write the destination IP
|
||||
dst.put_slice(&item.dst_ip.octets());
|
||||
// Write the data
|
||||
dst.extend_from_slice(&item.raw_data);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
401
mycelium/src/peer.rs
Normal file
401
mycelium/src/peer.rs
Normal file
@@ -0,0 +1,401 @@
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use std::{
|
||||
error::Error,
|
||||
io,
|
||||
sync::{
|
||||
atomic::{AtomicBool, AtomicU64, Ordering},
|
||||
Arc, RwLock, Weak,
|
||||
},
|
||||
};
|
||||
use tokio::{
|
||||
select,
|
||||
sync::{mpsc, Notify},
|
||||
};
|
||||
use tokio_util::codec::Framed;
|
||||
use tracing::{debug, error, info, trace};
|
||||
|
||||
use crate::{
|
||||
connection::{self, Connection},
|
||||
packet::{self, Packet},
|
||||
};
|
||||
use crate::{
|
||||
packet::{ControlPacket, DataPacket},
|
||||
sequence_number::SeqNo,
|
||||
};
|
||||
|
||||
/// The maximum amount of packets to immediately send if they are ready when the first one is
|
||||
/// received.
|
||||
const PACKET_COALESCE_WINDOW: usize = 50;
|
||||
|
||||
/// The default link cost assigned to new peers before their actual cost is known.
|
||||
///
|
||||
/// In theory, the best value would be U16::MAX - 1, however this value would take too long to be
|
||||
/// flushed out of the smoothed metric. A default of a 1000 (1 second) should be sufficiently large
|
||||
/// to cover very bad connections, so they also converge to a smaller value. While there is no
|
||||
/// issue with converging to a higher value (in other words, underestimating the latency to a
|
||||
/// peer), this means that bad peers would briefly be more likely to be selected. Additionally,
|
||||
/// since the latency increases, downstream peers would eventually find that the announced route
|
||||
/// would become unfeasible, and send a seqno request (which should solve this efficiently). As a
|
||||
/// tradeoff, it means it takes longer for new peers in the network to decrease to their actual
|
||||
/// metric (in comparisson with a lower starting metric), though this is in itself a usefull thing
|
||||
/// to have as it means peers joining the network would need to have some stability before being
|
||||
/// selected as hop.
|
||||
const DEFAULT_LINK_COST: u16 = 1000;
|
||||
|
||||
/// Multiplier for smoothed metric calculation of the existing smoothed metric.
|
||||
const EXISTING_METRIC_FACTOR: u32 = 9;
|
||||
/// Divisor for smoothed metric calcuation of the combined metric
|
||||
const TOTAL_METRIC_DIVISOR: u32 = 10;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
/// A peer represents a directly connected participant in the network.
|
||||
pub struct Peer {
|
||||
inner: Arc<PeerInner>,
|
||||
}
|
||||
|
||||
/// A weak reference to a peer, which does not prevent it from being cleaned up. This can be used
|
||||
/// to check liveliness of the [`Peer`] instance it originated from.
|
||||
pub struct PeerRef {
|
||||
inner: Weak<PeerInner>,
|
||||
}
|
||||
|
||||
impl Peer {
|
||||
pub fn new<C: Connection + Unpin + Send + 'static>(
|
||||
router_data_tx: mpsc::Sender<DataPacket>,
|
||||
router_control_tx: mpsc::UnboundedSender<(ControlPacket, Peer)>,
|
||||
connection: C,
|
||||
dead_peer_sink: mpsc::Sender<Peer>,
|
||||
bytes_written: Arc<AtomicU64>,
|
||||
bytes_read: Arc<AtomicU64>,
|
||||
) -> Result<Self, io::Error> {
|
||||
// Wrap connection so we can get access to the counters.
|
||||
let connection = connection::Tracked::new(bytes_read, bytes_written, connection);
|
||||
|
||||
// Data channel for peer
|
||||
let (to_peer_data, mut from_routing_data) = mpsc::unbounded_channel::<DataPacket>();
|
||||
// Control channel for peer
|
||||
let (to_peer_control, mut from_routing_control) =
|
||||
mpsc::unbounded_channel::<ControlPacket>();
|
||||
let death_notifier = Arc::new(Notify::new());
|
||||
let death_watcher = death_notifier.clone();
|
||||
let peer = Peer {
|
||||
inner: Arc::new(PeerInner {
|
||||
state: RwLock::new(PeerState::new()),
|
||||
to_peer_data,
|
||||
to_peer_control,
|
||||
connection_identifier: connection.identifier()?,
|
||||
static_link_cost: connection.static_link_cost()?,
|
||||
death_notifier,
|
||||
alive: AtomicBool::new(true),
|
||||
}),
|
||||
};
|
||||
|
||||
// Framed for peer
|
||||
// Used to send and receive packets from a TCP stream
|
||||
let framed = Framed::with_capacity(connection, packet::Codec::new(), 128 << 10);
|
||||
let (mut sink, mut stream) = framed.split();
|
||||
|
||||
{
|
||||
let peer = peer.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut needs_flush = false;
|
||||
loop {
|
||||
select! {
|
||||
// Received over the TCP stream
|
||||
frame = stream.next() => {
|
||||
match frame {
|
||||
Some(Ok(packet)) => {
|
||||
match packet {
|
||||
Packet::DataPacket(packet) => {
|
||||
// An error here means the receiver is dropped/closed,
|
||||
// this is not recoverable.
|
||||
if let Err(error) = router_data_tx.send(packet).await{
|
||||
error!("Error sending to to_routing_data: {}", error);
|
||||
break
|
||||
}
|
||||
}
|
||||
Packet::ControlPacket(packet) => {
|
||||
if let Err(error) = router_control_tx.send((packet, peer.clone())) {
|
||||
// An error here means the receiver is dropped/closed,
|
||||
// this is not recoverable.
|
||||
error!("Error sending to to_routing_control: {}", error);
|
||||
break
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(Err(e)) => {
|
||||
error!("Frame error from {}: {e}", peer.connection_identifier());
|
||||
break;
|
||||
},
|
||||
None => {
|
||||
info!("Stream to {} is closed", peer.connection_identifier());
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
rv = from_routing_data.recv(), if !needs_flush => {
|
||||
match rv {
|
||||
None => break,
|
||||
Some(packet) => {
|
||||
|
||||
needs_flush = true;
|
||||
|
||||
if let Err(e) = sink.feed(Packet::DataPacket(packet)).await {
|
||||
error!("Failed to feed data packet to connection: {e}");
|
||||
break
|
||||
}
|
||||
|
||||
|
||||
for _ in 1..PACKET_COALESCE_WINDOW {
|
||||
// There can be 2 cases of errors here, empty channel and no more
|
||||
// senders. In both cases we don't really care at this point.
|
||||
if let Ok(packet) = from_routing_data.try_recv() {
|
||||
if let Err(e) = sink.feed(Packet::DataPacket(packet)).await {
|
||||
error!("Failed to feed data packet to connection: {e}");
|
||||
break
|
||||
}
|
||||
trace!("Instantly queued ready packet to transfer to peer");
|
||||
} else {
|
||||
// No packets ready, flush currently buffered ones
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
rv = from_routing_control.recv(), if !needs_flush => {
|
||||
match rv {
|
||||
None => break,
|
||||
Some(packet) => {
|
||||
|
||||
needs_flush = true;
|
||||
|
||||
if let Err(e) = sink.feed(Packet::ControlPacket(packet)).await {
|
||||
error!("Failed to feed control packet to connection: {e}");
|
||||
break
|
||||
}
|
||||
|
||||
for _ in 1..PACKET_COALESCE_WINDOW {
|
||||
// There can be 2 cases of errors here, empty channel and no more
|
||||
// senders. In both cases we don't really care at this point.
|
||||
if let Ok(packet) = from_routing_control.try_recv() {
|
||||
if let Err(e) = sink.feed(Packet::ControlPacket(packet)).await {
|
||||
error!("Failed to feed data packet to connection: {e}");
|
||||
break
|
||||
}
|
||||
} else {
|
||||
// No packets ready, flush currently buffered ones
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
r = sink.flush(), if needs_flush => {
|
||||
if let Err(err) = r {
|
||||
error!("Failed to flush peer connection: {err}");
|
||||
break
|
||||
}
|
||||
needs_flush = false;
|
||||
}
|
||||
|
||||
_ = death_watcher.notified() => {
|
||||
// Attempt gracefull shutdown
|
||||
let mut framed = sink.reunite(stream).expect("SplitSink and SplitStream here can only be part of the same original Framned; Qed");
|
||||
let _ = framed.close().await;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Notify router we are dead, also modify our internal state to declare that.
|
||||
// Relaxed ordering is fine, we just care that the variable is set.
|
||||
peer.inner.alive.store(false, Ordering::Relaxed);
|
||||
let remote_id = peer.connection_identifier().clone();
|
||||
debug!("Notifying router peer {remote_id} is dead");
|
||||
if let Err(e) = dead_peer_sink.send(peer).await {
|
||||
error!("Peer {remote_id} could not notify router of termination: {e}");
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
Ok(peer)
|
||||
}
|
||||
|
||||
/// Get current sequence number for this peer.
|
||||
pub fn hello_seqno(&self) -> SeqNo {
|
||||
self.inner.state.read().unwrap().hello_seqno
|
||||
}
|
||||
|
||||
/// Adds 1 to the sequence number of this peer .
|
||||
pub fn increment_hello_seqno(&self) {
|
||||
self.inner.state.write().unwrap().hello_seqno += 1;
|
||||
}
|
||||
|
||||
pub fn time_last_received_hello(&self) -> tokio::time::Instant {
|
||||
self.inner.state.read().unwrap().time_last_received_hello
|
||||
}
|
||||
|
||||
pub fn set_time_last_received_hello(&self, time: tokio::time::Instant) {
|
||||
self.inner.state.write().unwrap().time_last_received_hello = time
|
||||
}
|
||||
|
||||
/// For sending data packets towards a peer instance on this node.
|
||||
/// It's send over the to_peer_data channel and read from the corresponding receiver.
|
||||
/// The receiver sends the packet over the TCP stream towards the destined peer instance on another node
|
||||
pub fn send_data_packet(&self, data_packet: DataPacket) -> Result<(), Box<dyn Error>> {
|
||||
Ok(self.inner.to_peer_data.send(data_packet)?)
|
||||
}
|
||||
|
||||
/// For sending control packets towards a peer instance on this node.
|
||||
/// It's send over the to_peer_control channel and read from the corresponding receiver.
|
||||
/// The receiver sends the packet over the TCP stream towards the destined peer instance on another node
|
||||
pub fn send_control_packet(&self, control_packet: ControlPacket) -> Result<(), Box<dyn Error>> {
|
||||
Ok(self.inner.to_peer_control.send(control_packet)?)
|
||||
}
|
||||
|
||||
/// Get the cost to use the peer, i.e. the additional impact on the [`crate::metric::Metric`]
|
||||
/// for using this `Peer`.
|
||||
///
|
||||
/// This is a smoothed value, which is calculated over the recent history of link cost.
|
||||
pub fn link_cost(&self) -> u16 {
|
||||
self.inner.state.read().unwrap().link_cost + self.inner.static_link_cost
|
||||
}
|
||||
|
||||
/// Sets the link cost based on the provided value.
|
||||
///
|
||||
/// The link cost is not set to the given value, but rather to an average of recent values.
|
||||
/// This makes sure short-lived, hard spikes of the link cost of a peer don't influence the
|
||||
/// routing.
|
||||
pub fn set_link_cost(&self, new_link_cost: u16) {
|
||||
// Calculate new link cost by multiplying (i.e. scaling) old and new link cost and
|
||||
// averaging them.
|
||||
let mut inner = self.inner.state.write().unwrap();
|
||||
inner.link_cost = (((inner.link_cost as u32) * EXISTING_METRIC_FACTOR
|
||||
+ (new_link_cost as u32) * (TOTAL_METRIC_DIVISOR - EXISTING_METRIC_FACTOR))
|
||||
/ TOTAL_METRIC_DIVISOR) as u16;
|
||||
}
|
||||
|
||||
/// Identifier for the connection to the `Peer`.
|
||||
pub fn connection_identifier(&self) -> &String {
|
||||
&self.inner.connection_identifier
|
||||
}
|
||||
|
||||
pub fn time_last_received_ihu(&self) -> tokio::time::Instant {
|
||||
self.inner.state.read().unwrap().time_last_received_ihu
|
||||
}
|
||||
|
||||
pub fn set_time_last_received_ihu(&self, time: tokio::time::Instant) {
|
||||
self.inner.state.write().unwrap().time_last_received_ihu = time
|
||||
}
|
||||
|
||||
/// Notify this `Peer` that it died.
|
||||
///
|
||||
/// While some [`Connection`] types can immediately detect that the connection itself is
|
||||
/// broken, not all of them can. In this scenario, we need to rely on an outside signal to tell
|
||||
/// us that we have, in fact, died.
|
||||
pub fn died(&self) {
|
||||
self.inner.alive.store(false, Ordering::Relaxed);
|
||||
self.inner.death_notifier.notify_one();
|
||||
}
|
||||
|
||||
/// Checks if the connection of this `Peer` is still alive.
|
||||
///
|
||||
/// For connection types which don't have (real time) state information, this might return a
|
||||
/// false positive if the connection has actually died, but the Peer did not notice this (yet)
|
||||
/// and hasn't been informed.
|
||||
pub fn alive(&self) -> bool {
|
||||
self.inner.alive.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// Create a new [`PeerRef`] that refers to this `Peer` instance.
|
||||
pub fn refer(&self) -> PeerRef {
|
||||
PeerRef {
|
||||
inner: Arc::downgrade(&self.inner),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PeerRef {
|
||||
/// Contructs a new `PeerRef` which is not associated with any actually [`Peer`].
|
||||
/// [`PeerRef::alive`] will always return false when called on this `PeerRef`.
|
||||
pub fn new() -> Self {
|
||||
PeerRef { inner: Weak::new() }
|
||||
}
|
||||
|
||||
/// Check if the connection of the [`Peer`] this `PeerRef` points to is still alive.
|
||||
pub fn alive(&self) -> bool {
|
||||
if let Some(peer) = self.inner.upgrade() {
|
||||
peer.alive.load(Ordering::Relaxed)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Attempts to convert this `PeerRef` into a full [`Peer`].
|
||||
pub fn upgrade(&self) -> Option<Peer> {
|
||||
self.inner.upgrade().map(|inner| Peer { inner })
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for PeerRef {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq for Peer {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
Arc::ptr_eq(&self.inner, &other.inner)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct PeerInner {
|
||||
state: RwLock<PeerState>,
|
||||
to_peer_data: mpsc::UnboundedSender<DataPacket>,
|
||||
to_peer_control: mpsc::UnboundedSender<ControlPacket>,
|
||||
/// Used to identify peer based on its connection params.
|
||||
connection_identifier: String,
|
||||
/// Static cost of using this link, to be added to the announced metric for routes through this
|
||||
/// Peer.
|
||||
static_link_cost: u16,
|
||||
/// Channel to notify the connection of its decease.
|
||||
death_notifier: Arc<Notify>,
|
||||
/// Keep track if the connection is alive.
|
||||
alive: AtomicBool,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct PeerState {
|
||||
hello_seqno: SeqNo,
|
||||
time_last_received_hello: tokio::time::Instant,
|
||||
link_cost: u16,
|
||||
time_last_received_ihu: tokio::time::Instant,
|
||||
}
|
||||
|
||||
impl PeerState {
|
||||
/// Create a new `PeerInner`, holding the mutable state of a [`Peer`]
|
||||
fn new() -> Self {
|
||||
// Initialize last_sent_hello_seqno to 0
|
||||
let hello_seqno = SeqNo::default();
|
||||
let link_cost = DEFAULT_LINK_COST;
|
||||
// Initialize time_last_received_hello to now
|
||||
let time_last_received_hello = tokio::time::Instant::now();
|
||||
// Initialiwe time_last_send_ihu
|
||||
let time_last_received_ihu = tokio::time::Instant::now();
|
||||
|
||||
Self {
|
||||
hello_seqno,
|
||||
link_cost,
|
||||
time_last_received_ihu,
|
||||
time_last_received_hello,
|
||||
}
|
||||
}
|
||||
}
|
||||
1361
mycelium/src/peer_manager.rs
Normal file
1361
mycelium/src/peer_manager.rs
Normal file
File diff suppressed because it is too large
Load Diff
2216
mycelium/src/router.rs
Normal file
2216
mycelium/src/router.rs
Normal file
File diff suppressed because it is too large
Load Diff
60
mycelium/src/router_id.rs
Normal file
60
mycelium/src/router_id.rs
Normal file
@@ -0,0 +1,60 @@
|
||||
use core::fmt;
|
||||
|
||||
use crate::crypto::PublicKey;
|
||||
|
||||
/// A `RouterId` uniquely identifies a router in the network.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub struct RouterId {
|
||||
pk: PublicKey,
|
||||
zone: [u8; 2],
|
||||
rnd: [u8; 6],
|
||||
}
|
||||
|
||||
impl RouterId {
|
||||
/// Size in bytes of a `RouterId`
|
||||
pub const BYTE_SIZE: usize = 40;
|
||||
|
||||
/// Create a new `RouterId` from a [`PublicKey`].
|
||||
pub fn new(pk: PublicKey) -> Self {
|
||||
Self {
|
||||
pk,
|
||||
zone: [0; 2],
|
||||
rnd: rand::random(),
|
||||
}
|
||||
}
|
||||
|
||||
/// View this `RouterId` as a byte array.
|
||||
pub fn as_bytes(&self) -> [u8; Self::BYTE_SIZE] {
|
||||
let mut out = [0; Self::BYTE_SIZE];
|
||||
out[..32].copy_from_slice(self.pk.as_bytes());
|
||||
out[32..34].copy_from_slice(&self.zone);
|
||||
out[34..].copy_from_slice(&self.rnd);
|
||||
out
|
||||
}
|
||||
|
||||
/// Converts this `RouterId` to a [`PublicKey`].
|
||||
pub fn to_pubkey(self) -> PublicKey {
|
||||
self.pk
|
||||
}
|
||||
}
|
||||
|
||||
impl From<[u8; Self::BYTE_SIZE]> for RouterId {
|
||||
fn from(bytes: [u8; Self::BYTE_SIZE]) -> RouterId {
|
||||
RouterId {
|
||||
pk: PublicKey::from(<&[u8] as TryInto<[u8; 32]>>::try_into(&bytes[..32]).unwrap()),
|
||||
zone: bytes[32..34].try_into().unwrap(),
|
||||
rnd: bytes[34..Self::BYTE_SIZE].try_into().unwrap(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for RouterId {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let RouterId { pk, zone, rnd } = self;
|
||||
f.write_fmt(format_args!(
|
||||
"{pk}-{}-{}",
|
||||
faster_hex::hex_string(zone),
|
||||
faster_hex::hex_string(rnd)
|
||||
))
|
||||
}
|
||||
}
|
||||
687
mycelium/src/routing_table.rs
Normal file
687
mycelium/src/routing_table.rs
Normal file
@@ -0,0 +1,687 @@
|
||||
use std::{
|
||||
net::{IpAddr, Ipv6Addr},
|
||||
ops::Deref,
|
||||
sync::{Arc, Mutex, MutexGuard},
|
||||
};
|
||||
|
||||
use ip_network_table_deps_treebitmap::IpLookupTable;
|
||||
use iter::{RoutingTableNoRouteIter, RoutingTableQueryIter};
|
||||
use subnet_entry::SubnetEntry;
|
||||
use tokio::{select, sync::mpsc, time::Duration};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{error, trace};
|
||||
|
||||
use crate::{crypto::SharedSecret, peer::Peer, subnet::Subnet};
|
||||
|
||||
pub use iter::RoutingTableIter;
|
||||
pub use iter_mut::RoutingTableIterMut;
|
||||
pub use no_route::NoRouteSubnet;
|
||||
pub use queried_subnet::QueriedSubnet;
|
||||
pub use route_entry::RouteEntry;
|
||||
pub use route_key::RouteKey;
|
||||
pub use route_list::RouteList;
|
||||
|
||||
mod iter;
|
||||
mod iter_mut;
|
||||
mod no_route;
|
||||
mod queried_subnet;
|
||||
mod route_entry;
|
||||
mod route_key;
|
||||
mod route_list;
|
||||
mod subnet_entry;
|
||||
|
||||
const NO_ROUTE_EXPIRATION: Duration = Duration::from_secs(60);
|
||||
|
||||
pub enum Routes {
|
||||
Exist(RouteListReadGuard),
|
||||
Queried,
|
||||
NoRoute,
|
||||
None,
|
||||
}
|
||||
|
||||
impl Routes {
|
||||
/// Returns the selected route if one exists.
|
||||
pub fn selected(&self) -> Option<&RouteEntry> {
|
||||
if let Routes::Exist(routes) = self {
|
||||
routes.selected()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if there are no routes
|
||||
pub fn is_none(&self) -> bool {
|
||||
!matches!(self, Routes::Exist { .. })
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&SubnetEntry> for Routes {
|
||||
fn from(value: &SubnetEntry) -> Self {
|
||||
match value {
|
||||
SubnetEntry::Exists { list } => {
|
||||
Routes::Exist(RouteListReadGuard { inner: list.load() })
|
||||
}
|
||||
SubnetEntry::Queried { .. } => Routes::Queried,
|
||||
SubnetEntry::NoRoute { .. } => Routes::NoRoute,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Option<&SubnetEntry>> for Routes {
|
||||
fn from(value: Option<&SubnetEntry>) -> Self {
|
||||
match value {
|
||||
Some(v) => v.into(),
|
||||
None => Routes::None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The routing table holds a list of route entries for every known subnet.
|
||||
#[derive(Clone)]
|
||||
pub struct RoutingTable {
|
||||
writer: Arc<Mutex<left_right::WriteHandle<RoutingTableInner, RoutingTableOplogEntry>>>,
|
||||
reader: left_right::ReadHandle<RoutingTableInner>,
|
||||
|
||||
shared: Arc<RoutingTableShared>,
|
||||
}
|
||||
|
||||
struct RoutingTableShared {
|
||||
expired_route_entry_sink: mpsc::Sender<RouteKey>,
|
||||
cancel_token: CancellationToken,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct RoutingTableInner {
|
||||
table: IpLookupTable<Ipv6Addr, Arc<SubnetEntry>>,
|
||||
}
|
||||
|
||||
/// Hold an exclusive write lock over the routing table. While this item is in scope, no other
|
||||
/// calls can get a mutable refernce to the content of a routing table. Once this guard goes out of
|
||||
/// scope, changes to the contained RouteList will be applied.
|
||||
pub struct WriteGuard<'a> {
|
||||
routing_table: &'a RoutingTable,
|
||||
/// Owned copy of the RouteList, this is populated once mutable access the the RouteList has
|
||||
/// been requested.
|
||||
value: Arc<SubnetEntry>,
|
||||
/// Did the RouteList exist initially?
|
||||
exists: bool,
|
||||
/// The subnet we are writing to.
|
||||
subnet: Subnet,
|
||||
expired_route_entry_sink: mpsc::Sender<RouteKey>,
|
||||
cancellation_token: CancellationToken,
|
||||
}
|
||||
|
||||
impl RoutingTable {
|
||||
/// Create a new empty RoutingTable. The passed channel is used to notify an external observer
|
||||
/// of route entry expiration events. It is the callers responsibility to ensure these events
|
||||
/// are properly handled.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// This will panic if not executed in the context of a tokio runtime.
|
||||
pub fn new(expired_route_entry_sink: mpsc::Sender<RouteKey>) -> Self {
|
||||
let (writer, reader) = left_right::new();
|
||||
let writer = Arc::new(Mutex::new(writer));
|
||||
|
||||
let cancel_token = CancellationToken::new();
|
||||
let shared = Arc::new(RoutingTableShared {
|
||||
expired_route_entry_sink,
|
||||
cancel_token,
|
||||
});
|
||||
|
||||
RoutingTable {
|
||||
writer,
|
||||
reader,
|
||||
shared,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a list of the routes for the most precises [`Subnet`] known which contains the given
|
||||
/// [`IpAddr`].
|
||||
pub fn best_routes(&self, ip: IpAddr) -> Routes {
|
||||
let IpAddr::V6(ip) = ip else {
|
||||
panic!("Only IPv6 is supported currently");
|
||||
};
|
||||
self.reader
|
||||
.enter()
|
||||
.expect("Write handle is saved on the router so it is not dropped yet.")
|
||||
.table
|
||||
.longest_match(ip)
|
||||
.map(|(_, _, rl)| rl.as_ref())
|
||||
.into()
|
||||
}
|
||||
|
||||
/// Get a list of all routes for the given subnet. Changes to the RoutingTable after this
|
||||
/// method returns will not be visible and require this method to be called again to be
|
||||
/// observed.
|
||||
pub fn routes(&self, subnet: Subnet) -> Routes {
|
||||
let subnet_ip = if let IpAddr::V6(ip) = subnet.address() {
|
||||
ip
|
||||
} else {
|
||||
return Routes::None;
|
||||
};
|
||||
|
||||
self.reader
|
||||
.enter()
|
||||
.expect("Write handle is saved on the router so it is not dropped yet.")
|
||||
.table
|
||||
.exact_match(subnet_ip, subnet.prefix_len().into())
|
||||
.map(Arc::as_ref)
|
||||
.into()
|
||||
}
|
||||
|
||||
/// Gets continued read access to the `RoutingTable`. While the returned
|
||||
/// [`guard`](RoutingTableReadGuard) is held, updates to the `RoutingTable` will be blocked.
|
||||
pub fn read(&self) -> RoutingTableReadGuard {
|
||||
RoutingTableReadGuard {
|
||||
guard: self
|
||||
.reader
|
||||
.enter()
|
||||
.expect("Write handle is saved on RoutingTable, so this is always Some; qed"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Locks the `RoutingTable` for continued write access. While the returned
|
||||
/// [`guard`](RoutingTableWriteGuard) is held, methods trying to mutate the `RoutingTable`, or
|
||||
/// get mutable access otherwise, will be blocked. When the [`guard`](`RoutingTableWriteGuard`)
|
||||
/// is dropped, all queued changes will be applied.
|
||||
pub fn write(&self) -> RoutingTableWriteGuard {
|
||||
RoutingTableWriteGuard {
|
||||
write_guard: self.writer.lock().unwrap(),
|
||||
read_guard: self
|
||||
.reader
|
||||
.enter()
|
||||
.expect("Write handle is saved on RoutingTable, so this is always Some; qed"),
|
||||
expired_route_entry_sink: self.shared.expired_route_entry_sink.clone(),
|
||||
cancel_token: self.shared.cancel_token.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get mutable access to the list of routes for the given [`Subnet`].
|
||||
pub fn routes_mut(&self, subnet: Subnet) -> Option<WriteGuard> {
|
||||
let subnet_address = if let IpAddr::V6(ip) = subnet.address() {
|
||||
ip
|
||||
} else {
|
||||
panic!("IP v4 addresses are not supported")
|
||||
};
|
||||
|
||||
let value = self
|
||||
.reader
|
||||
.enter()
|
||||
.expect("Write handle is saved next to read handle so this is always Some; qed")
|
||||
.table
|
||||
.exact_match(subnet_address, subnet.prefix_len().into())?
|
||||
.clone();
|
||||
|
||||
if matches!(*value, SubnetEntry::Exists { .. }) {
|
||||
Some(WriteGuard {
|
||||
routing_table: self,
|
||||
// If we didn't find a route list in the route table we create a new empty list,
|
||||
// therefore we immediately own it.
|
||||
value,
|
||||
exists: true,
|
||||
subnet,
|
||||
expired_route_entry_sink: self.shared.expired_route_entry_sink.clone(),
|
||||
cancellation_token: self.shared.cancel_token.clone(),
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Adds a new [`Subnet`] to the `RoutingTable`. The returned [`WriteGuard`] can be used to
|
||||
/// insert entries. If no entry is inserted before the guard is dropped, the [`Subnet`] won't
|
||||
/// be added.
|
||||
pub fn add_subnet(&self, subnet: Subnet, shared_secret: SharedSecret) -> WriteGuard {
|
||||
if !matches!(subnet.address(), IpAddr::V6(_)) {
|
||||
panic!("IP v4 addresses are not supported")
|
||||
};
|
||||
|
||||
let value = Arc::new(SubnetEntry::Exists {
|
||||
list: Arc::new(RouteList::new(shared_secret)).into(),
|
||||
});
|
||||
|
||||
WriteGuard {
|
||||
routing_table: self,
|
||||
value,
|
||||
exists: false,
|
||||
subnet,
|
||||
expired_route_entry_sink: self.shared.expired_route_entry_sink.clone(),
|
||||
cancellation_token: self.shared.cancel_token.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Gets the selected route for an IpAddr if one exists.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// This will panic if the IP address is not an IPV6 address.
|
||||
pub fn selected_route(&self, address: IpAddr) -> Option<RouteEntry> {
|
||||
let IpAddr::V6(ip) = address else {
|
||||
panic!("IP v4 addresses are not supported")
|
||||
};
|
||||
self.reader
|
||||
.enter()
|
||||
.expect("Write handle is saved on RoutingTable, so this is always Some; qed")
|
||||
.table
|
||||
.longest_match(ip)
|
||||
.and_then(|(_, _, rl)| {
|
||||
let SubnetEntry::Exists { list } = &**rl else {
|
||||
return None;
|
||||
};
|
||||
let rl = list.load();
|
||||
if rl.is_empty() || !rl[0].selected() {
|
||||
None
|
||||
} else {
|
||||
Some(rl[0].clone())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Marks a subnet as queried in the route table.
|
||||
///
|
||||
/// This function will not do anything if the subnet contains valid routes.
|
||||
pub fn mark_queried(&self, subnet: Subnet, query_timeout: tokio::time::Instant) {
|
||||
if !matches!(subnet.address(), IpAddr::V6(_)) {
|
||||
panic!("IP v4 addresses are not supported")
|
||||
};
|
||||
|
||||
// Start a task to expire the queried state if we didn't have any results in time.
|
||||
{
|
||||
// We only need the write handle in the task
|
||||
let writer = self.writer.clone();
|
||||
let cancel_token = self.shared.cancel_token.clone();
|
||||
tokio::task::spawn(async move {
|
||||
select! {
|
||||
_ = cancel_token.cancelled() => {
|
||||
// Future got cancelled, nothing to do
|
||||
return
|
||||
}
|
||||
_ = tokio::time::sleep_until(query_timeout) => {
|
||||
// Timeout fired, mark as no route
|
||||
}
|
||||
}
|
||||
|
||||
let expiry = tokio::time::Instant::now() + NO_ROUTE_EXPIRATION;
|
||||
|
||||
// Scope this so the lock for the write_handle goes out of scope when we are done
|
||||
// here, as we don't want to hold the write_handle lock while sleeping for the
|
||||
// second timeout.
|
||||
{
|
||||
let mut write_handle = writer.lock().expect("Can lock writer");
|
||||
write_handle.append(RoutingTableOplogEntry::QueryExpired(
|
||||
subnet,
|
||||
Arc::new(SubnetEntry::NoRoute { expiry }),
|
||||
));
|
||||
write_handle.flush();
|
||||
}
|
||||
|
||||
// TODO: Check if we are indeed marked as NoRoute here, if we aren't this can be
|
||||
// cancelled now
|
||||
|
||||
select! {
|
||||
_ = cancel_token.cancelled() => {
|
||||
// Future got cancelled, nothing to do
|
||||
return
|
||||
}
|
||||
_ = tokio::time::sleep_until(expiry) => {
|
||||
// Timeout fired, remove no route entry
|
||||
}
|
||||
}
|
||||
|
||||
let mut write_handle = writer.lock().expect("Can lock writer");
|
||||
write_handle.append(RoutingTableOplogEntry::NoRouteExpired(subnet));
|
||||
write_handle.flush();
|
||||
});
|
||||
}
|
||||
|
||||
let mut write_handle = self.writer.lock().expect("Can lock writer");
|
||||
write_handle.append(RoutingTableOplogEntry::Queried(
|
||||
subnet,
|
||||
Arc::new(SubnetEntry::Queried { query_timeout }),
|
||||
));
|
||||
write_handle.flush();
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RouteListReadGuard {
|
||||
inner: arc_swap::Guard<Arc<RouteList>>,
|
||||
}
|
||||
|
||||
impl Deref for RouteListReadGuard {
|
||||
type Target = RouteList;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
self.inner.deref()
|
||||
}
|
||||
}
|
||||
|
||||
/// A write guard over the [`RoutingTable`]. While this guard is held, updates won't be able to
|
||||
/// complete.
|
||||
pub struct RoutingTableWriteGuard<'a> {
|
||||
write_guard: MutexGuard<'a, left_right::WriteHandle<RoutingTableInner, RoutingTableOplogEntry>>,
|
||||
read_guard: left_right::ReadGuard<'a, RoutingTableInner>,
|
||||
expired_route_entry_sink: mpsc::Sender<RouteKey>,
|
||||
cancel_token: CancellationToken,
|
||||
}
|
||||
|
||||
impl<'a, 'b> RoutingTableWriteGuard<'a> {
|
||||
pub fn iter_mut(&'b mut self) -> RoutingTableIterMut<'a, 'b> {
|
||||
RoutingTableIterMut::new(
|
||||
&mut self.write_guard,
|
||||
self.read_guard.table.iter(),
|
||||
self.expired_route_entry_sink.clone(),
|
||||
self.cancel_token.clone(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for RoutingTableWriteGuard<'_> {
|
||||
fn drop(&mut self) {
|
||||
self.write_guard.publish();
|
||||
}
|
||||
}
|
||||
|
||||
/// A read guard over the [`RoutingTable`]. While this guard is held, updates won't be able to
|
||||
/// complete.
|
||||
pub struct RoutingTableReadGuard<'a> {
|
||||
guard: left_right::ReadGuard<'a, RoutingTableInner>,
|
||||
}
|
||||
|
||||
impl RoutingTableReadGuard<'_> {
|
||||
pub fn iter(&self) -> RoutingTableIter {
|
||||
RoutingTableIter::new(self.guard.table.iter())
|
||||
}
|
||||
|
||||
/// Create an iterator for all queried subnets in the routing table
|
||||
pub fn iter_queries(&self) -> RoutingTableQueryIter {
|
||||
RoutingTableQueryIter::new(self.guard.table.iter())
|
||||
}
|
||||
|
||||
/// Create an iterator for all subnets which are currently marked as `NoRoute` in the routing
|
||||
/// table.
|
||||
pub fn iter_no_route(&self) -> RoutingTableNoRouteIter {
|
||||
RoutingTableNoRouteIter::new(self.guard.table.iter())
|
||||
}
|
||||
}
|
||||
|
||||
impl WriteGuard<'_> {
|
||||
/// Loads the current [`RouteList`].
|
||||
#[inline]
|
||||
pub fn routes(&self) -> RouteListReadGuard {
|
||||
let SubnetEntry::Exists { list } = &*self.value else {
|
||||
panic!("Write guard for non-route SubnetEntry")
|
||||
};
|
||||
RouteListReadGuard { inner: list.load() }
|
||||
}
|
||||
|
||||
/// Get mutable access to the [`RouteList`]. This will update the [`RouteList`] in place
|
||||
/// without locking the [`RoutingTable`].
|
||||
// TODO: Proper abstractions
|
||||
pub fn update_routes<
|
||||
F: FnMut(&mut RouteList, &mpsc::Sender<RouteKey>, &CancellationToken) -> bool,
|
||||
>(
|
||||
&mut self,
|
||||
mut op: F,
|
||||
) -> bool {
|
||||
let mut res = false;
|
||||
let mut delete = false;
|
||||
if let SubnetEntry::Exists { list } = &*self.value {
|
||||
list.rcu(|rl| {
|
||||
let mut new_val = rl.clone();
|
||||
let v = Arc::make_mut(&mut new_val);
|
||||
|
||||
res = op(v, &self.expired_route_entry_sink, &self.cancellation_token);
|
||||
delete = v.is_empty();
|
||||
|
||||
new_val
|
||||
});
|
||||
|
||||
if delete && self.exists {
|
||||
trace!(subnet = %self.subnet, "Deleting subnet which became empty after updating");
|
||||
let mut writer = self.routing_table.writer.lock().unwrap();
|
||||
|
||||
writer.append(RoutingTableOplogEntry::Delete(self.subnet));
|
||||
writer.publish();
|
||||
}
|
||||
|
||||
res
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the [`RouteEntry`] with the given [`neighbour`](Peer) as the selected route.
|
||||
pub fn set_selected(&mut self, neighbour: &Peer) {
|
||||
if let SubnetEntry::Exists { list } = &*self.value {
|
||||
list.rcu(|routes| {
|
||||
let mut new_routes = routes.clone();
|
||||
let routes = Arc::make_mut(&mut new_routes);
|
||||
let Some(pos) = routes.iter().position(|re| re.neighbour() == neighbour) else {
|
||||
error!(
|
||||
neighbour = neighbour.connection_identifier(),
|
||||
"Failed to select route entry with given route key, no such entry"
|
||||
);
|
||||
return new_routes;
|
||||
};
|
||||
|
||||
// We don't need a check for an empty list here, since we found a selected route there
|
||||
// _MUST_ be at least 1 entry.
|
||||
// Set the first element to unselected, then select the proper element so this also works
|
||||
// in case the existing route is "reselected".
|
||||
routes[0].set_selected(false);
|
||||
routes[pos].set_selected(true);
|
||||
routes.swap(0, pos);
|
||||
|
||||
new_routes
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/// Unconditionally unselects the selected route, if one is present.
|
||||
///
|
||||
/// In case no route is selected, this is a no-op.
|
||||
pub fn unselect(&mut self) {
|
||||
if let SubnetEntry::Exists { list } = &*self.value {
|
||||
list.rcu(|v| {
|
||||
let mut new_val = v.clone();
|
||||
let new_ref = Arc::make_mut(&mut new_val);
|
||||
|
||||
if let Some(e) = new_ref.get_mut(0) {
|
||||
e.set_selected(false);
|
||||
}
|
||||
|
||||
new_val
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for WriteGuard<'_> {
|
||||
fn drop(&mut self) {
|
||||
// FIXME: try to get rid of clones on the Arc here
|
||||
if let SubnetEntry::Exists { list } = &*self.value {
|
||||
let value = list.load();
|
||||
match self.exists {
|
||||
// The route list did not exist, and now it is not empty, so an entry was added. We
|
||||
// need to add the route list to the routing table.
|
||||
false if !value.is_empty() => {
|
||||
trace!(subnet = %self.subnet, "Inserting new route list for subnet");
|
||||
let mut writer = self.routing_table.writer.lock().unwrap();
|
||||
writer.append(RoutingTableOplogEntry::Upsert(
|
||||
self.subnet,
|
||||
Arc::clone(&self.value),
|
||||
));
|
||||
writer.publish();
|
||||
}
|
||||
// There was an existing route list which is now empty, so the entry for this subnet
|
||||
// needs to be deleted in the routing table.
|
||||
true if value.is_empty() => {
|
||||
trace!(subnet = %self.subnet, "Removing route list for subnet");
|
||||
let mut writer = self.routing_table.writer.lock().unwrap();
|
||||
writer.append(RoutingTableOplogEntry::Delete(self.subnet));
|
||||
writer.publish();
|
||||
}
|
||||
// Nothing to do in these cases. Either no value was inserted in a non existing
|
||||
// routelist, or an existing one was updated in place.
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Operations allowed on the left_right for the routing table.
|
||||
enum RoutingTableOplogEntry {
|
||||
/// Insert or Update the value for the given subnet.
|
||||
Upsert(Subnet, Arc<SubnetEntry>),
|
||||
/// Mark a subnet as queried.
|
||||
Queried(Subnet, Arc<SubnetEntry>),
|
||||
/// Delete the entry for the given subnet.
|
||||
Delete(Subnet),
|
||||
/// The route request for a subnet expired, if it is still in query state mark it as not
|
||||
/// existing
|
||||
QueryExpired(Subnet, Arc<SubnetEntry>),
|
||||
/// The marker for explicitly not having a route to a subnet has expired
|
||||
NoRouteExpired(Subnet),
|
||||
}
|
||||
|
||||
/// Convert an [`IpAddr`] into an [`Ipv6Addr`]. Panics if the contained addrss is not an IPv6
|
||||
/// address.
|
||||
fn expect_ipv6(ip: IpAddr) -> Ipv6Addr {
|
||||
let IpAddr::V6(ip) = ip else {
|
||||
panic!("Expected ipv6 address")
|
||||
};
|
||||
|
||||
ip
|
||||
}
|
||||
|
||||
impl left_right::Absorb<RoutingTableOplogEntry> for RoutingTableInner {
|
||||
fn absorb_first(&mut self, operation: &mut RoutingTableOplogEntry, _other: &Self) {
|
||||
match operation {
|
||||
RoutingTableOplogEntry::Upsert(subnet, list) => {
|
||||
self.table.insert(
|
||||
expect_ipv6(subnet.address()),
|
||||
subnet.prefix_len().into(),
|
||||
Arc::clone(list),
|
||||
);
|
||||
}
|
||||
|
||||
RoutingTableOplogEntry::Queried(subnet, se) => {
|
||||
// Mark a query only if we don't have a valid entry
|
||||
let entry = self
|
||||
.table
|
||||
.exact_match(expect_ipv6(subnet.address()), subnet.prefix_len().into())
|
||||
.map(Arc::deref);
|
||||
|
||||
// If we have no route, transition to query, if we have a route or existing query,
|
||||
// do nothing
|
||||
if matches!(entry, None | Some(SubnetEntry::NoRoute { .. })) {
|
||||
self.table.insert(
|
||||
expect_ipv6(subnet.address()),
|
||||
subnet.prefix_len().into(),
|
||||
Arc::clone(se),
|
||||
);
|
||||
}
|
||||
}
|
||||
RoutingTableOplogEntry::Delete(subnet) => {
|
||||
self.table
|
||||
.remove(expect_ipv6(subnet.address()), subnet.prefix_len().into());
|
||||
}
|
||||
RoutingTableOplogEntry::QueryExpired(subnet, nre) => {
|
||||
if let Some(entry) = self
|
||||
.table
|
||||
.exact_match(expect_ipv6(subnet.address()), subnet.prefix_len().into())
|
||||
{
|
||||
if let SubnetEntry::Queried { .. } = &**entry {
|
||||
self.table.insert(
|
||||
expect_ipv6(subnet.address()),
|
||||
subnet.prefix_len().into(),
|
||||
Arc::clone(nre),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
RoutingTableOplogEntry::NoRouteExpired(subnet) => {
|
||||
if let Some(entry) = self
|
||||
.table
|
||||
.exact_match(expect_ipv6(subnet.address()), subnet.prefix_len().into())
|
||||
{
|
||||
if let SubnetEntry::NoRoute { .. } = &**entry {
|
||||
self.table
|
||||
.remove(expect_ipv6(subnet.address()), subnet.prefix_len().into());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn sync_with(&mut self, first: &Self) {
|
||||
for (k, ss, v) in first.table.iter() {
|
||||
self.table.insert(k, ss, v.clone());
|
||||
}
|
||||
}
|
||||
|
||||
fn absorb_second(&mut self, operation: RoutingTableOplogEntry, _: &Self) {
|
||||
match operation {
|
||||
RoutingTableOplogEntry::Upsert(subnet, list) => {
|
||||
self.table.insert(
|
||||
expect_ipv6(subnet.address()),
|
||||
subnet.prefix_len().into(),
|
||||
list,
|
||||
);
|
||||
}
|
||||
RoutingTableOplogEntry::Queried(subnet, se) => {
|
||||
// Mark a query only if we don't have a valid entry
|
||||
let entry = self
|
||||
.table
|
||||
.exact_match(expect_ipv6(subnet.address()), subnet.prefix_len().into())
|
||||
.map(Arc::deref);
|
||||
|
||||
// If we have no route, transition to query, if we have a route or existing query,
|
||||
// do nothing
|
||||
if matches!(entry, None | Some(SubnetEntry::NoRoute { .. })) {
|
||||
self.table.insert(
|
||||
expect_ipv6(subnet.address()),
|
||||
subnet.prefix_len().into(),
|
||||
se,
|
||||
);
|
||||
}
|
||||
}
|
||||
RoutingTableOplogEntry::Delete(subnet) => {
|
||||
self.table
|
||||
.remove(expect_ipv6(subnet.address()), subnet.prefix_len().into());
|
||||
}
|
||||
RoutingTableOplogEntry::QueryExpired(subnet, nre) => {
|
||||
if let Some(entry) = self
|
||||
.table
|
||||
.exact_match(expect_ipv6(subnet.address()), subnet.prefix_len().into())
|
||||
{
|
||||
if let SubnetEntry::Queried { .. } = &**entry {
|
||||
self.table.insert(
|
||||
expect_ipv6(subnet.address()),
|
||||
subnet.prefix_len().into(),
|
||||
nre,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
RoutingTableOplogEntry::NoRouteExpired(subnet) => {
|
||||
if let Some(entry) = self
|
||||
.table
|
||||
.exact_match(expect_ipv6(subnet.address()), subnet.prefix_len().into())
|
||||
{
|
||||
if let SubnetEntry::NoRoute { .. } = &**entry {
|
||||
self.table
|
||||
.remove(expect_ipv6(subnet.address()), subnet.prefix_len().into());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for RoutingTableShared {
|
||||
fn drop(&mut self) {
|
||||
self.cancel_token.cancel();
|
||||
}
|
||||
}
|
||||
100
mycelium/src/routing_table/iter.rs
Normal file
100
mycelium/src/routing_table/iter.rs
Normal file
@@ -0,0 +1,100 @@
|
||||
use std::{net::Ipv6Addr, sync::Arc};
|
||||
|
||||
use crate::subnet::Subnet;
|
||||
|
||||
use super::{subnet_entry::SubnetEntry, NoRouteSubnet, QueriedSubnet, RouteListReadGuard};
|
||||
|
||||
/// An iterator over a [`routing table`](super::RoutingTable) giving read only access to
|
||||
/// [`RouteList`]'s.
|
||||
pub struct RoutingTableIter<'a>(
|
||||
ip_network_table_deps_treebitmap::Iter<'a, Ipv6Addr, Arc<SubnetEntry>>,
|
||||
);
|
||||
|
||||
impl<'a> RoutingTableIter<'a> {
|
||||
/// Create a new `RoutingTableIter` which will iterate over all entries in a [`RoutingTable`].
|
||||
pub(super) fn new(
|
||||
inner: ip_network_table_deps_treebitmap::Iter<'a, Ipv6Addr, Arc<SubnetEntry>>,
|
||||
) -> Self {
|
||||
Self(inner)
|
||||
}
|
||||
}
|
||||
|
||||
impl Iterator for RoutingTableIter<'_> {
|
||||
type Item = (Subnet, RouteListReadGuard);
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
for (ip, prefix_size, rl) in self.0.by_ref() {
|
||||
if let SubnetEntry::Exists { list } = &**rl {
|
||||
return Some((
|
||||
Subnet::new(ip.into(), prefix_size as u8)
|
||||
.expect("Routing table contains valid subnets"),
|
||||
RouteListReadGuard { inner: list.load() },
|
||||
));
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Iterator over queried routes in the routing table.
|
||||
pub struct RoutingTableQueryIter<'a>(
|
||||
ip_network_table_deps_treebitmap::Iter<'a, Ipv6Addr, Arc<SubnetEntry>>,
|
||||
);
|
||||
|
||||
impl<'a> RoutingTableQueryIter<'a> {
|
||||
/// Create a new `RoutingTableQueryIter` which will iterate over all queried entries in a [`RoutingTable`].
|
||||
pub(super) fn new(
|
||||
inner: ip_network_table_deps_treebitmap::Iter<'a, Ipv6Addr, Arc<SubnetEntry>>,
|
||||
) -> Self {
|
||||
Self(inner)
|
||||
}
|
||||
}
|
||||
|
||||
impl Iterator for RoutingTableQueryIter<'_> {
|
||||
type Item = QueriedSubnet;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
for (ip, prefix_size, rl) in self.0.by_ref() {
|
||||
if let SubnetEntry::Queried { query_timeout } = &**rl {
|
||||
return Some(QueriedSubnet::new(
|
||||
Subnet::new(ip.into(), prefix_size as u8)
|
||||
.expect("Routing table contains valid subnets"),
|
||||
*query_timeout,
|
||||
));
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Iterator for entries which are explicitly marked as "no route"in the routing table.
|
||||
pub struct RoutingTableNoRouteIter<'a>(
|
||||
ip_network_table_deps_treebitmap::Iter<'a, Ipv6Addr, Arc<SubnetEntry>>,
|
||||
);
|
||||
|
||||
impl<'a> RoutingTableNoRouteIter<'a> {
|
||||
/// Create a new `RoutingTableNoRouteIter` which will iterate over all entries in a [`RoutingTable`]
|
||||
/// which are explicitly marked as `NoRoute`
|
||||
pub(super) fn new(
|
||||
inner: ip_network_table_deps_treebitmap::Iter<'a, Ipv6Addr, Arc<SubnetEntry>>,
|
||||
) -> Self {
|
||||
Self(inner)
|
||||
}
|
||||
}
|
||||
|
||||
impl Iterator for RoutingTableNoRouteIter<'_> {
|
||||
type Item = NoRouteSubnet;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
for (ip, prefix_size, rl) in self.0.by_ref() {
|
||||
if let SubnetEntry::NoRoute { expiry } = &**rl {
|
||||
return Some(NoRouteSubnet::new(
|
||||
Subnet::new(ip.into(), prefix_size as u8)
|
||||
.expect("Routing table contains valid subnets"),
|
||||
*expiry,
|
||||
));
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
107
mycelium/src/routing_table/iter_mut.rs
Normal file
107
mycelium/src/routing_table/iter_mut.rs
Normal file
@@ -0,0 +1,107 @@
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::trace;
|
||||
|
||||
use crate::subnet::Subnet;
|
||||
|
||||
use super::{
|
||||
subnet_entry::SubnetEntry, RouteKey, RouteList, RoutingTableInner, RoutingTableOplogEntry,
|
||||
};
|
||||
use std::{
|
||||
net::Ipv6Addr,
|
||||
sync::{Arc, MutexGuard},
|
||||
};
|
||||
|
||||
/// An iterator over a [`routing table`](super::RoutingTable), yielding mutable access to the
|
||||
/// entries in the table.
|
||||
pub struct RoutingTableIterMut<'a, 'b> {
|
||||
write_guard:
|
||||
&'b mut MutexGuard<'a, left_right::WriteHandle<RoutingTableInner, RoutingTableOplogEntry>>,
|
||||
iter: ip_network_table_deps_treebitmap::Iter<'b, Ipv6Addr, Arc<SubnetEntry>>,
|
||||
|
||||
expired_route_entry_sink: mpsc::Sender<RouteKey>,
|
||||
cancel_token: CancellationToken,
|
||||
}
|
||||
|
||||
impl<'a, 'b> RoutingTableIterMut<'a, 'b> {
|
||||
pub(super) fn new(
|
||||
write_guard: &'b mut MutexGuard<
|
||||
'a,
|
||||
left_right::WriteHandle<RoutingTableInner, RoutingTableOplogEntry>,
|
||||
>,
|
||||
iter: ip_network_table_deps_treebitmap::Iter<'b, Ipv6Addr, Arc<SubnetEntry>>,
|
||||
|
||||
expired_route_entry_sink: mpsc::Sender<RouteKey>,
|
||||
cancel_token: CancellationToken,
|
||||
) -> Self {
|
||||
Self {
|
||||
write_guard,
|
||||
iter,
|
||||
expired_route_entry_sink,
|
||||
cancel_token,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the next item in this iterator. This is not implemented as the [`Iterator`] trait,
|
||||
/// since we hand out items which are lifetime bound to this struct.
|
||||
pub fn next<'c>(&'c mut self) -> Option<(Subnet, RoutingTableIterMutEntry<'a, 'c>)> {
|
||||
for (ip, prefix_size, rl) in self.iter.by_ref() {
|
||||
if matches!(&**rl, SubnetEntry::Exists { .. }) {
|
||||
let subnet = Subnet::new(ip.into(), prefix_size as u8)
|
||||
.expect("Routing table contains valid subnets");
|
||||
return Some((
|
||||
subnet,
|
||||
RoutingTableIterMutEntry {
|
||||
writer: self.write_guard,
|
||||
store: Arc::clone(rl),
|
||||
subnet,
|
||||
expired_route_entry_sink: self.expired_route_entry_sink.clone(),
|
||||
cancellation_token: self.cancel_token.clone(),
|
||||
},
|
||||
));
|
||||
};
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// A smart pointer giving mutable access to a [`RouteList`].
|
||||
pub struct RoutingTableIterMutEntry<'a, 'b> {
|
||||
writer:
|
||||
&'b mut MutexGuard<'a, left_right::WriteHandle<RoutingTableInner, RoutingTableOplogEntry>>,
|
||||
/// Owned copy of the RouteList, this is populated once mutable access the the RouteList has
|
||||
/// been requested.
|
||||
store: Arc<SubnetEntry>,
|
||||
/// The subnet we are writing to.
|
||||
subnet: Subnet,
|
||||
expired_route_entry_sink: mpsc::Sender<RouteKey>,
|
||||
cancellation_token: CancellationToken,
|
||||
}
|
||||
|
||||
impl RoutingTableIterMutEntry<'_, '_> {
|
||||
/// Updates the routes for this entry
|
||||
pub fn update_routes<F: FnMut(&mut RouteList, &mpsc::Sender<RouteKey>, &CancellationToken)>(
|
||||
&mut self,
|
||||
mut op: F,
|
||||
) {
|
||||
let mut delete = false;
|
||||
if let SubnetEntry::Exists { list } = &*self.store {
|
||||
list.rcu(|rl| {
|
||||
let mut new_val = rl.clone();
|
||||
let v = Arc::make_mut(&mut new_val);
|
||||
|
||||
op(v, &self.expired_route_entry_sink, &self.cancellation_token);
|
||||
delete = v.is_empty();
|
||||
|
||||
new_val
|
||||
});
|
||||
|
||||
if delete {
|
||||
trace!(subnet = %self.subnet, "Queue subnet for deletion since route list is now empty");
|
||||
self.writer
|
||||
.append(RoutingTableOplogEntry::Delete(self.subnet));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
35
mycelium/src/routing_table/no_route.rs
Normal file
35
mycelium/src/routing_table/no_route.rs
Normal file
@@ -0,0 +1,35 @@
|
||||
use tokio::time::Instant;
|
||||
|
||||
use crate::subnet::Subnet;
|
||||
|
||||
/// Information about a [`subnet`](Subnet) which is currently marked as NoRoute.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct NoRouteSubnet {
|
||||
/// The subnet which has no route.
|
||||
subnet: Subnet,
|
||||
/// Time at which the entry expires. After this timeout expires, the entry is removed and a new
|
||||
/// query can be performed.
|
||||
entry_expires: Instant,
|
||||
}
|
||||
|
||||
impl NoRouteSubnet {
|
||||
/// Create a new `NoRouteSubnet` for the given [`subnet`](Subnet), expiring at the provided
|
||||
/// [`time`](Instant).
|
||||
pub fn new(subnet: Subnet, entry_expires: Instant) -> Self {
|
||||
Self {
|
||||
subnet,
|
||||
entry_expires,
|
||||
}
|
||||
}
|
||||
|
||||
/// The [`subnet`](Subnet) for which there is no route.
|
||||
pub fn subnet(&self) -> Subnet {
|
||||
self.subnet
|
||||
}
|
||||
|
||||
/// The moment this entry expires. Once this timeout expires, a new query can be launched for
|
||||
/// route discovery for this [`subnet`](Subnet).
|
||||
pub fn entry_expires(&self) -> Instant {
|
||||
self.entry_expires
|
||||
}
|
||||
}
|
||||
35
mycelium/src/routing_table/queried_subnet.rs
Normal file
35
mycelium/src/routing_table/queried_subnet.rs
Normal file
@@ -0,0 +1,35 @@
|
||||
use tokio::time::Instant;
|
||||
|
||||
use crate::subnet::Subnet;
|
||||
|
||||
/// Information about a [`subnet`](Subnet) which is currently in the queried state
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct QueriedSubnet {
|
||||
/// The subnet which was queried.
|
||||
subnet: Subnet,
|
||||
/// Time at which the query expires. If no feasible updates come in before this, the subnet is
|
||||
/// marked as no route temporarily.
|
||||
query_expires: Instant,
|
||||
}
|
||||
|
||||
impl QueriedSubnet {
|
||||
/// Create a new `QueriedSubnet` for the given [`subnet`](Subnet), expiring at the provided
|
||||
/// [`time`](Instant).
|
||||
pub fn new(subnet: Subnet, query_expires: Instant) -> Self {
|
||||
Self {
|
||||
subnet,
|
||||
query_expires,
|
||||
}
|
||||
}
|
||||
|
||||
/// The [`subnet`](Subnet) being queried.
|
||||
pub fn subnet(&self) -> Subnet {
|
||||
self.subnet
|
||||
}
|
||||
|
||||
/// The moment this query expires. If no route is discovered before this, the [`subnet`](Subnet)
|
||||
/// is marked as no route temporarily.
|
||||
pub fn query_expires(&self) -> Instant {
|
||||
self.query_expires
|
||||
}
|
||||
}
|
||||
120
mycelium/src/routing_table/route_entry.rs
Normal file
120
mycelium/src/routing_table/route_entry.rs
Normal file
@@ -0,0 +1,120 @@
|
||||
use tokio::time::Instant;
|
||||
|
||||
use crate::{
|
||||
metric::Metric, peer::Peer, router_id::RouterId, sequence_number::SeqNo,
|
||||
source_table::SourceKey,
|
||||
};
|
||||
|
||||
/// RouteEntry holds all relevant information about a specific route. Since this includes the next
|
||||
/// hop, a single subnet can have multiple route entries.
|
||||
#[derive(Clone)]
|
||||
pub struct RouteEntry {
|
||||
source: SourceKey,
|
||||
neighbour: Peer,
|
||||
metric: Metric,
|
||||
seqno: SeqNo,
|
||||
selected: bool,
|
||||
expires: Instant,
|
||||
}
|
||||
|
||||
impl RouteEntry {
|
||||
/// Create a new `RouteEntry` with the provided values.
|
||||
pub fn new(
|
||||
source: SourceKey,
|
||||
neighbour: Peer,
|
||||
metric: Metric,
|
||||
seqno: SeqNo,
|
||||
selected: bool,
|
||||
expires: Instant,
|
||||
) -> Self {
|
||||
Self {
|
||||
source,
|
||||
neighbour,
|
||||
metric,
|
||||
seqno,
|
||||
selected,
|
||||
expires,
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the [`SourceKey`] for this `RouteEntry`.
|
||||
pub fn source(&self) -> SourceKey {
|
||||
self.source
|
||||
}
|
||||
|
||||
/// Return the [`neighbour`](Peer) used as next hop for this `RouteEntry`.
|
||||
pub fn neighbour(&self) -> &Peer {
|
||||
&self.neighbour
|
||||
}
|
||||
|
||||
/// Return the [`Metric`] of this `RouteEntry`.
|
||||
pub fn metric(&self) -> Metric {
|
||||
self.metric
|
||||
}
|
||||
|
||||
/// Return the [`sequence number`](SeqNo) for the `RouteEntry`.
|
||||
pub fn seqno(&self) -> SeqNo {
|
||||
self.seqno
|
||||
}
|
||||
|
||||
/// Return if this [`RouteEntry`] is selected.
|
||||
pub fn selected(&self) -> bool {
|
||||
self.selected
|
||||
}
|
||||
|
||||
/// Return the [`Instant`] when this `RouteEntry` expires if it doesn't get updated before
|
||||
/// then.
|
||||
pub fn expires(&self) -> Instant {
|
||||
self.expires
|
||||
}
|
||||
|
||||
/// Set the [`SourceKey`] for this `RouteEntry`.
|
||||
pub fn set_source(&mut self, source: SourceKey) {
|
||||
self.source = source;
|
||||
}
|
||||
|
||||
/// Set the [`RouterId`] for this `RouteEntry`.
|
||||
pub fn set_router_id(&mut self, router_id: RouterId) {
|
||||
self.source.set_router_id(router_id)
|
||||
}
|
||||
|
||||
/// Sets the [`neighbour`](Peer) for this `RouteEntry`.
|
||||
pub fn set_neighbour(&mut self, neighbour: Peer) {
|
||||
self.neighbour = neighbour;
|
||||
}
|
||||
|
||||
/// Sets the [`Metric`] for this `RouteEntry`.
|
||||
pub fn set_metric(&mut self, metric: Metric) {
|
||||
self.metric = metric;
|
||||
}
|
||||
|
||||
/// Sets the [`sequence number`](SeqNo) for this `RouteEntry`.
|
||||
pub fn set_seqno(&mut self, seqno: SeqNo) {
|
||||
self.seqno = seqno;
|
||||
}
|
||||
|
||||
/// Sets if this `RouteEntry` is the selected route for the associated
|
||||
/// [`Subnet`](crate::subnet::Subnet).
|
||||
pub fn set_selected(&mut self, selected: bool) {
|
||||
self.selected = selected;
|
||||
}
|
||||
|
||||
/// Sets the expiration time for this [`RouteEntry`].
|
||||
pub(super) fn set_expires(&mut self, expires: Instant) {
|
||||
self.expires = expires;
|
||||
}
|
||||
}
|
||||
|
||||
// Manual Debug implementation since SharedSecret is explicitly not Debug
|
||||
impl std::fmt::Debug for RouteEntry {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("RouteEntry")
|
||||
.field("source", &self.source)
|
||||
.field("neighbour", &self.neighbour)
|
||||
.field("metric", &self.metric)
|
||||
.field("seqno", &self.seqno)
|
||||
.field("selected", &self.selected)
|
||||
.field("expires", &self.expires)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
38
mycelium/src/routing_table/route_key.rs
Normal file
38
mycelium/src/routing_table/route_key.rs
Normal file
@@ -0,0 +1,38 @@
|
||||
use crate::{peer::Peer, subnet::Subnet};
|
||||
|
||||
/// RouteKey uniquely defines a route via a peer.
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct RouteKey {
|
||||
subnet: Subnet,
|
||||
neighbour: Peer,
|
||||
}
|
||||
|
||||
impl RouteKey {
|
||||
/// Creates a new `RouteKey` for the given [`Subnet`] and [`neighbour`](Peer).
|
||||
#[inline]
|
||||
pub fn new(subnet: Subnet, neighbour: Peer) -> Self {
|
||||
Self { subnet, neighbour }
|
||||
}
|
||||
|
||||
/// Get's the [`Subnet`] identified by this `RouteKey`.
|
||||
#[inline]
|
||||
pub fn subnet(&self) -> Subnet {
|
||||
self.subnet
|
||||
}
|
||||
|
||||
/// Gets the [`neighbour`](Peer) identified by this `RouteKey`.
|
||||
#[inline]
|
||||
pub fn neighbour(&self) -> &Peer {
|
||||
&self.neighbour
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for RouteKey {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_fmt(format_args!(
|
||||
"{} via {}",
|
||||
self.subnet,
|
||||
self.neighbour.connection_identifier()
|
||||
))
|
||||
}
|
||||
}
|
||||
201
mycelium/src/routing_table/route_list.rs
Normal file
201
mycelium/src/routing_table/route_list.rs
Normal file
@@ -0,0 +1,201 @@
|
||||
use std::{
|
||||
ops::{Deref, DerefMut, Index, IndexMut},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{debug, error};
|
||||
|
||||
use crate::{crypto::SharedSecret, peer::Peer, task::AbortHandle};
|
||||
|
||||
use super::{RouteEntry, RouteKey};
|
||||
|
||||
/// The RouteList holds all routes for a specific subnet.
|
||||
// By convention, if a route is selected, it will always be at index 0 in the list.
|
||||
#[derive(Clone)]
|
||||
pub struct RouteList {
|
||||
list: Vec<(Arc<AbortHandle>, RouteEntry)>,
|
||||
shared_secret: SharedSecret,
|
||||
}
|
||||
|
||||
impl RouteList {
|
||||
/// Create a new empty RouteList
|
||||
pub(crate) fn new(shared_secret: SharedSecret) -> Self {
|
||||
Self {
|
||||
list: Vec::new(),
|
||||
shared_secret,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the [`SharedSecret`] used for encryption of packets to and from the associated
|
||||
/// [`Subnet`].
|
||||
#[inline]
|
||||
pub fn shared_secret(&self) -> &SharedSecret {
|
||||
&self.shared_secret
|
||||
}
|
||||
|
||||
/// Checks if there are any actual routes in the list.
|
||||
#[inline]
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.list.is_empty()
|
||||
}
|
||||
|
||||
/// Returns the selected route for the [`Subnet`] this is the `RouteList` for, if one exists.
|
||||
pub fn selected(&self) -> Option<&RouteEntry> {
|
||||
self.list
|
||||
.first()
|
||||
.map(|(_, re)| re)
|
||||
.and_then(|re| if re.selected() { Some(re) } else { None })
|
||||
}
|
||||
|
||||
/// Returns an iterator over the `RouteList`.
|
||||
///
|
||||
/// The iterator yields all [`route entries`](RouteEntry) in the list.
|
||||
pub fn iter(&self) -> RouteListIter {
|
||||
RouteListIter::new(self)
|
||||
}
|
||||
|
||||
/// Returns an iterator over the `RouteList` yielding mutable access to the elements.
|
||||
///
|
||||
/// The iterator yields all [`route entries`](RouteEntry) in the list.
|
||||
pub fn iter_mut(&mut self) -> impl Iterator<Item = RouteGuard> {
|
||||
self.list.iter_mut().map(|item| RouteGuard { item })
|
||||
}
|
||||
|
||||
/// Removes a [`RouteEntry`] from the `RouteList`.
|
||||
///
|
||||
/// This does nothing if the neighbour does not exist.
|
||||
pub fn remove(&mut self, neighbour: &Peer) {
|
||||
let Some(pos) = self
|
||||
.list
|
||||
.iter()
|
||||
.position(|re| re.1.neighbour() == neighbour)
|
||||
else {
|
||||
return;
|
||||
};
|
||||
|
||||
let old = self.list.swap_remove(pos);
|
||||
old.0.abort();
|
||||
}
|
||||
|
||||
/// Swaps the position of 2 `RouteEntry`s in the route list.
|
||||
pub fn swap(&mut self, first: usize, second: usize) {
|
||||
self.list.swap(first, second)
|
||||
}
|
||||
|
||||
pub fn get_mut(&mut self, index: usize) -> Option<&mut RouteEntry> {
|
||||
self.list.get_mut(index).map(|(_, re)| re)
|
||||
}
|
||||
|
||||
/// Insert a new [`RouteEntry`] in the `RouteList`.
|
||||
pub fn insert(
|
||||
&mut self,
|
||||
re: RouteEntry,
|
||||
expired_route_entry_sink: mpsc::Sender<RouteKey>,
|
||||
cancellation_token: CancellationToken,
|
||||
) {
|
||||
let expiration = re.expires();
|
||||
let rk = RouteKey::new(re.source().subnet(), re.neighbour().clone());
|
||||
let abort_handle = Arc::new(
|
||||
tokio::spawn(async move {
|
||||
tokio::select! {
|
||||
_ = cancellation_token.cancelled() => {}
|
||||
_ = tokio::time::sleep_until(expiration) => {
|
||||
debug!(route_key = %rk, "Expired route entry for route key");
|
||||
if let Err(e) = expired_route_entry_sink.send(rk).await {
|
||||
error!(route_key = %e.0, "Failed to send expired route key on cleanup channel");
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.abort_handle().into(),
|
||||
);
|
||||
|
||||
self.list.push((abort_handle, re));
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RouteGuard<'a> {
|
||||
item: &'a mut (Arc<AbortHandle>, RouteEntry),
|
||||
}
|
||||
|
||||
impl Deref for RouteGuard<'_> {
|
||||
type Target = RouteEntry;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.item.1
|
||||
}
|
||||
}
|
||||
|
||||
impl DerefMut for RouteGuard<'_> {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.item.1
|
||||
}
|
||||
}
|
||||
|
||||
impl RouteGuard<'_> {
|
||||
pub fn set_expires(
|
||||
&mut self,
|
||||
expires: tokio::time::Instant,
|
||||
expired_route_entry_sink: mpsc::Sender<RouteKey>,
|
||||
cancellation_token: CancellationToken,
|
||||
) {
|
||||
let re = &mut self.item.1;
|
||||
re.set_expires(expires);
|
||||
let expiration = re.expires();
|
||||
let rk = RouteKey::new(re.source().subnet(), re.neighbour().clone());
|
||||
let abort_handle = Arc::new(
|
||||
tokio::spawn(async move {
|
||||
tokio::select! {
|
||||
_ = cancellation_token.cancelled() => {}
|
||||
_ = tokio::time::sleep_until(expiration) => {
|
||||
debug!(route_key = %rk, "Expired route entry for route key");
|
||||
if let Err(e) = expired_route_entry_sink.send(rk).await {
|
||||
error!(route_key = %e.0, "Failed to send expired route key on cleanup channel");
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.abort_handle().into(),
|
||||
);
|
||||
|
||||
self.item.0.abort();
|
||||
self.item.0 = abort_handle;
|
||||
}
|
||||
}
|
||||
|
||||
impl Index<usize> for RouteList {
|
||||
type Output = RouteEntry;
|
||||
|
||||
fn index(&self, index: usize) -> &Self::Output {
|
||||
&self.list[index].1
|
||||
}
|
||||
}
|
||||
|
||||
impl IndexMut<usize> for RouteList {
|
||||
fn index_mut(&mut self, index: usize) -> &mut Self::Output {
|
||||
&mut self.list[index].1
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RouteListIter<'a> {
|
||||
route_list: &'a RouteList,
|
||||
idx: usize,
|
||||
}
|
||||
|
||||
impl<'a> RouteListIter<'a> {
|
||||
/// Create a new `RouteListIter` which will iterate over the given [`RouteList`].
|
||||
fn new(route_list: &'a RouteList) -> Self {
|
||||
Self { route_list, idx: 0 }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Iterator for RouteListIter<'a> {
|
||||
type Item = &'a RouteEntry;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
self.idx += 1;
|
||||
self.route_list.list.get(self.idx - 1).map(|(_, re)| re)
|
||||
}
|
||||
}
|
||||
16
mycelium/src/routing_table/subnet_entry.rs
Normal file
16
mycelium/src/routing_table/subnet_entry.rs
Normal file
@@ -0,0 +1,16 @@
|
||||
use arc_swap::ArcSwap;
|
||||
|
||||
use super::RouteList;
|
||||
|
||||
/// An entry for a [Subnet](crate::subnet::Subnet) in the routing table.
|
||||
#[allow(dead_code)]
|
||||
pub enum SubnetEntry {
|
||||
/// Routes for the given subnet exist
|
||||
Exists { list: ArcSwap<RouteList> },
|
||||
/// Routes are being queried from peers for the given subnet, but we haven't gotten a response
|
||||
/// yet
|
||||
Queried { query_timeout: tokio::time::Instant },
|
||||
/// We queried our peers for the subnet, but we didn't get a valid response in time, so there
|
||||
/// is for sure no route to the subnet.
|
||||
NoRoute { expiry: tokio::time::Instant },
|
||||
}
|
||||
103
mycelium/src/rr_cache.rs
Normal file
103
mycelium/src/rr_cache.rs
Normal file
@@ -0,0 +1,103 @@
|
||||
//! This module contains a cache implementation for route requests
|
||||
|
||||
use std::{
|
||||
net::{IpAddr, Ipv6Addr},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use dashmap::DashMap;
|
||||
use tokio::time::{Duration, Instant};
|
||||
use tracing::trace;
|
||||
|
||||
use crate::{babel::RouteRequest, peer::Peer, subnet::Subnet, task::AbortHandle};
|
||||
|
||||
/// Clean the route request cache every 5 seconds
|
||||
const CACHE_CLEANING_INTERVAL: Duration = Duration::from_secs(5);
|
||||
|
||||
/// IP used for the [`Subnet`] in the cache in case there is no prefix specified.
|
||||
const GLOBAL_SUBNET_IP: IpAddr = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0));
|
||||
|
||||
/// Prefix size to use for the [`Subnet`] in case there is no prefix specified.
|
||||
const GLOBAL_SUBNET_PREFIX_SIZE: u8 = 0;
|
||||
|
||||
/// A self cleaning cache for route requests.
|
||||
#[derive(Clone)]
|
||||
pub struct RouteRequestCache {
|
||||
/// The actual cache, mapping an instance of a route request to the peers which we've sent this
|
||||
/// to.
|
||||
cache: Arc<DashMap<Subnet, RouteRequestInfo, ahash::RandomState>>,
|
||||
|
||||
_cleanup_task: Arc<AbortHandle>,
|
||||
}
|
||||
|
||||
struct RouteRequestInfo {
|
||||
/// The lowest generation we've forwarded.
|
||||
generation: u8,
|
||||
/// Peers which we've sent this route request to already.
|
||||
receivers: Vec<Peer>,
|
||||
/// The moment we've sent this route request
|
||||
sent: Instant,
|
||||
}
|
||||
|
||||
impl RouteRequestCache {
|
||||
/// Create a new cache which cleans entries which are older than the given expiration.
|
||||
///
|
||||
/// The cache cleaning is done periodically, so entries might live slightly longer than the
|
||||
/// allowed expiration.
|
||||
pub fn new(expiration: Duration) -> Self {
|
||||
let cache = Arc::new(DashMap::with_hasher(ahash::RandomState::new()));
|
||||
|
||||
let _cleanup_task = Arc::new(
|
||||
tokio::spawn({
|
||||
let cache = cache.clone();
|
||||
async move {
|
||||
loop {
|
||||
tokio::time::sleep(CACHE_CLEANING_INTERVAL).await;
|
||||
|
||||
trace!("Cleaning route request cache");
|
||||
|
||||
cache.retain(|subnet, info: &mut RouteRequestInfo| {
|
||||
if info.sent.elapsed() < expiration {
|
||||
false
|
||||
} else {
|
||||
trace!(%subnet, "Removing exired route request from cache");
|
||||
true
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
})
|
||||
.abort_handle()
|
||||
.into(),
|
||||
);
|
||||
|
||||
Self {
|
||||
cache,
|
||||
_cleanup_task,
|
||||
}
|
||||
}
|
||||
|
||||
/// Record a route request which has been sent to peers.
|
||||
pub fn sent_route_request(&self, rr: RouteRequest, receivers: Vec<Peer>) {
|
||||
let subnet = rr.prefix().unwrap_or(
|
||||
Subnet::new(GLOBAL_SUBNET_IP, GLOBAL_SUBNET_PREFIX_SIZE)
|
||||
.expect("Static global IPv6 subnet is valid; qed"),
|
||||
);
|
||||
let generation = rr.generation();
|
||||
|
||||
let rri = RouteRequestInfo {
|
||||
generation,
|
||||
receivers,
|
||||
sent: Instant::now(),
|
||||
};
|
||||
|
||||
self.cache.insert(subnet, rri);
|
||||
}
|
||||
|
||||
/// Get cached info about a route request for a subnet, if it exists.
|
||||
pub fn info(&self, subnet: Subnet) -> Option<(u8, Vec<Peer>)> {
|
||||
self.cache
|
||||
.get(&subnet)
|
||||
.map(|rri| (rri.generation, rri.receivers.clone()))
|
||||
}
|
||||
}
|
||||
176
mycelium/src/seqno_cache.rs
Normal file
176
mycelium/src/seqno_cache.rs
Normal file
@@ -0,0 +1,176 @@
|
||||
//! The seqno request cache keeps track of seqno requests sent by the node. This allows us to drop
|
||||
//! duplicate requests, and to notify the source of requests (if it wasn't the local node) about
|
||||
//! relevant updates.
|
||||
|
||||
use std::{
|
||||
sync::Arc,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
use dashmap::DashMap;
|
||||
use tokio::time::MissedTickBehavior;
|
||||
use tracing::{debug, trace};
|
||||
|
||||
use crate::{peer::Peer, router_id::RouterId, sequence_number::SeqNo, subnet::Subnet};
|
||||
|
||||
/// The amount of time to remember a seqno request (since it was first seen), before we remove it
|
||||
/// (assuming it was not removed manually before that).
|
||||
const SEQNO_DEDUP_TTL: Duration = Duration::from_secs(60);
|
||||
|
||||
/// A sequence number request, either forwarded or originated by the local node.
|
||||
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub struct SeqnoRequestCacheKey {
|
||||
pub router_id: RouterId,
|
||||
pub subnet: Subnet,
|
||||
pub seqno: SeqNo,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for SeqnoRequestCacheKey {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"seqno {} for {} from {}",
|
||||
self.seqno, self.subnet, self.router_id
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Information retained for sequence number requests we've sent.
|
||||
struct SeqnoForwardInfo {
|
||||
/// Which peers have asked us to forward this seqno request.
|
||||
sources: Vec<Peer>,
|
||||
/// Which peers have we sent this request to.
|
||||
targets: Vec<Peer>,
|
||||
/// Time at which we first forwarded the requets.
|
||||
first_sent: Instant,
|
||||
/// When did we last sent a seqno request.
|
||||
last_sent: Instant,
|
||||
}
|
||||
|
||||
/// A cache for outbound seqno requests. Entries in the cache are automatically removed after a
|
||||
/// certain amount of time. The cache does not account for the source table. That is, if the
|
||||
/// requested seqno is smaller, it might pass the cache, but should have been blocked earlier by
|
||||
/// the source table check. As such, this cache should be the last step in deciding if a seqno
|
||||
/// request is forwarded.
|
||||
#[derive(Clone)]
|
||||
pub struct SeqnoCache {
|
||||
/// Actual cache wrapped in an Arc to make it sharaeble.
|
||||
cache: Arc<DashMap<SeqnoRequestCacheKey, SeqnoForwardInfo, ahash::RandomState>>,
|
||||
}
|
||||
|
||||
impl SeqnoCache {
|
||||
/// Create a new [`SeqnoCache`].
|
||||
pub fn new() -> Self {
|
||||
trace!(capacity = 0, "Creating new seqno cache");
|
||||
|
||||
let cache = Arc::new(DashMap::with_hasher_and_shard_amount(
|
||||
ahash::RandomState::new(),
|
||||
// This number has been chosen completely at random
|
||||
1024,
|
||||
));
|
||||
let sc = Self { cache };
|
||||
|
||||
// Spawn background cleanup task.
|
||||
tokio::spawn(sc.clone().sweep_entries());
|
||||
|
||||
sc
|
||||
}
|
||||
|
||||
/// Record a forwarded seqno request to a given target. Also keep track of the origin of the
|
||||
/// request. If the local node generated the request, source must be [`None`]
|
||||
pub fn forward(&self, request: SeqnoRequestCacheKey, target: Peer, source: Option<Peer>) {
|
||||
let mut info = self.cache.entry(request).or_default();
|
||||
info.last_sent = Instant::now();
|
||||
if !info.targets.contains(&target) {
|
||||
info.targets.push(target);
|
||||
} else {
|
||||
debug!(
|
||||
seqno_request = %request,
|
||||
"Already sent seqno request to target {}",
|
||||
target.connection_identifier()
|
||||
);
|
||||
}
|
||||
if let Some(source) = source {
|
||||
if !info.sources.contains(&source) {
|
||||
info.sources.push(source);
|
||||
} else {
|
||||
debug!(seqno_request = %request, "Peer {} is requesting the same seqno again", source.connection_identifier());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a list of all peers which we've already sent the given seqno request to, as well as
|
||||
/// when we've last sent a request.
|
||||
pub fn info(&self, request: &SeqnoRequestCacheKey) -> Option<(Instant, Vec<Peer>)> {
|
||||
self.cache
|
||||
.get(request)
|
||||
.map(|info| (info.last_sent, info.targets.clone()))
|
||||
}
|
||||
|
||||
/// Removes forwarding info from the seqno cache. If forwarding info is available, the source
|
||||
/// peers (peers which requested us to forward this request) are returned.
|
||||
// TODO: cleanup if needed
|
||||
#[allow(dead_code)]
|
||||
pub fn remove(&self, request: &SeqnoRequestCacheKey) -> Option<Vec<Peer>> {
|
||||
self.cache.remove(request).map(|(_, info)| info.sources)
|
||||
}
|
||||
|
||||
/// Get forwarding info from the seqno cache. If forwarding info is available, the source
|
||||
/// peers (peers which requested us to forward this request) are returned.
|
||||
// TODO: cleanup if needed
|
||||
#[allow(dead_code)]
|
||||
pub fn get(&self, request: &SeqnoRequestCacheKey) -> Option<Vec<Peer>> {
|
||||
self.cache.get(request).map(|info| info.sources.clone())
|
||||
}
|
||||
|
||||
/// Periodic task to clear old entries for which no reply came in.
|
||||
async fn sweep_entries(self) {
|
||||
let mut interval = tokio::time::interval(SEQNO_DEDUP_TTL);
|
||||
interval.set_missed_tick_behavior(MissedTickBehavior::Skip);
|
||||
|
||||
loop {
|
||||
interval.tick().await;
|
||||
|
||||
debug!("Cleaning up expired seqno requests from seqno cache");
|
||||
|
||||
let prev_entries = self.cache.len();
|
||||
let prev_cap = self.cache.capacity();
|
||||
self.cache
|
||||
.retain(|_, info| info.first_sent.elapsed() <= SEQNO_DEDUP_TTL);
|
||||
self.cache.shrink_to_fit();
|
||||
|
||||
debug!(
|
||||
cleaned_entries = prev_entries - self.cache.len(),
|
||||
removed_capacity = prev_cap - self.cache.capacity(),
|
||||
"Cleaned up stale seqno request cache entries"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SeqnoCache {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SeqnoForwardInfo {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
sources: vec![],
|
||||
targets: vec![],
|
||||
first_sent: Instant::now(),
|
||||
last_sent: Instant::now(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for SeqnoRequestCacheKey {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("SeqnoRequestCacheKey")
|
||||
.field("router_id", &self.router_id.to_string())
|
||||
.field("subnet", &self.subnet.to_string())
|
||||
.field("seqno", &self.seqno.to_string())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
153
mycelium/src/sequence_number.rs
Normal file
153
mycelium/src/sequence_number.rs
Normal file
@@ -0,0 +1,153 @@
|
||||
//! Dedicated logic for
|
||||
//! [sequence numbers](https://datatracker.ietf.org/doc/html/rfc8966#name-solving-starvation-sequenci).
|
||||
|
||||
use core::fmt;
|
||||
use core::ops::{Add, AddAssign};
|
||||
|
||||
/// This value is compared against when deciding if a `SeqNo` is larger or smaller, [as defined in
|
||||
/// the babel rfc](https://datatracker.ietf.org/doc/html/rfc8966#section-3.2.1).
|
||||
const SEQNO_COMPARE_TRESHOLD: u16 = 32_768;
|
||||
|
||||
/// A sequence number on a route.
|
||||
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub struct SeqNo(u16);
|
||||
|
||||
impl SeqNo {
|
||||
/// Create a new `SeqNo` with the default value.
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Custom PartialOrd implementation as defined in [the babel rfc](https://datatracker.ietf.org/doc/html/rfc8966#section-3.2.1).
|
||||
/// Note that we don't implement the [`PartialOrd`](std::cmd::PartialOrd) trait, as the contract on
|
||||
/// that trait specifically defines that it is transitive, which is clearly not the case here.
|
||||
///
|
||||
/// There is a quirk in this equality comparison where values which are exactly 32_768 apart,
|
||||
/// will result in false in either way of ordering the arguments, which is counterintuitive to
|
||||
/// our understanding that a < b generally implies !(b < a).
|
||||
pub fn lt(&self, other: &Self) -> bool {
|
||||
if self.0 == other.0 {
|
||||
false
|
||||
} else {
|
||||
other.0.wrapping_sub(self.0) < SEQNO_COMPARE_TRESHOLD
|
||||
}
|
||||
}
|
||||
|
||||
/// Custom PartialOrd implementation as defined in [the babel rfc](https://datatracker.ietf.org/doc/html/rfc8966#section-3.2.1).
|
||||
/// Note that we don't implement the [`PartialOrd`](std::cmd::PartialOrd) trait, as the contract on
|
||||
/// that trait specifically defines that it is transitive, which is clearly not the case here.
|
||||
///
|
||||
/// There is a quirk in this equality comparison where values which are exactly 32_768 apart,
|
||||
/// will result in false in either way of ordering the arguments, which is counterintuitive to
|
||||
/// our understanding that a < b generally implies !(b < a).
|
||||
pub fn gt(&self, other: &Self) -> bool {
|
||||
if self.0 == other.0 {
|
||||
false
|
||||
} else {
|
||||
other.0.wrapping_sub(self.0) > SEQNO_COMPARE_TRESHOLD
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for SeqNo {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.write_fmt(format_args!("{}", self.0))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<u16> for SeqNo {
|
||||
fn from(value: u16) -> Self {
|
||||
SeqNo(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<SeqNo> for u16 {
|
||||
fn from(value: SeqNo) -> Self {
|
||||
value.0
|
||||
}
|
||||
}
|
||||
|
||||
impl Add<u16> for SeqNo {
|
||||
type Output = Self;
|
||||
|
||||
fn add(self, rhs: u16) -> Self::Output {
|
||||
SeqNo(self.0.wrapping_add(rhs))
|
||||
}
|
||||
}
|
||||
|
||||
impl AddAssign<u16> for SeqNo {
|
||||
fn add_assign(&mut self, rhs: u16) {
|
||||
*self = SeqNo(self.0.wrapping_add(rhs))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::SeqNo;
|
||||
|
||||
#[test]
|
||||
fn cmp_eq_seqno() {
|
||||
let s1 = SeqNo::from(1);
|
||||
let s2 = SeqNo::from(1);
|
||||
assert_eq!(s1, s2);
|
||||
|
||||
let s1 = SeqNo::from(10_000);
|
||||
let s2 = SeqNo::from(10_000);
|
||||
assert_eq!(s1, s2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cmp_small_seqno_increase() {
|
||||
let s1 = SeqNo::from(1);
|
||||
let s2 = SeqNo::from(2);
|
||||
assert!(s1.lt(&s2));
|
||||
assert!(!s2.lt(&s1));
|
||||
|
||||
assert!(s2.gt(&s1));
|
||||
assert!(!s1.gt(&s2));
|
||||
|
||||
let s1 = SeqNo::from(3);
|
||||
let s2 = SeqNo::from(30_000);
|
||||
assert!(s1.lt(&s2));
|
||||
assert!(!s2.lt(&s1));
|
||||
|
||||
assert!(s2.gt(&s1));
|
||||
assert!(!s1.gt(&s2));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cmp_big_seqno_increase() {
|
||||
let s1 = SeqNo::from(0);
|
||||
let s2 = SeqNo::from(32_767);
|
||||
assert!(s1.lt(&s2));
|
||||
assert!(!s2.lt(&s1));
|
||||
|
||||
assert!(s2.gt(&s1));
|
||||
assert!(!s1.gt(&s2));
|
||||
|
||||
// Test equality quirk at cutoff point.
|
||||
let s1 = SeqNo::from(0);
|
||||
let s2 = SeqNo::from(32_768);
|
||||
assert!(!s1.lt(&s2));
|
||||
assert!(!s2.lt(&s1));
|
||||
|
||||
assert!(!s2.gt(&s1));
|
||||
assert!(!s1.gt(&s2));
|
||||
|
||||
let s1 = SeqNo::from(0);
|
||||
let s2 = SeqNo::from(32_769);
|
||||
assert!(!s1.lt(&s2));
|
||||
assert!(s2.lt(&s1));
|
||||
|
||||
assert!(!s2.gt(&s1));
|
||||
assert!(s1.gt(&s2));
|
||||
|
||||
let s1 = SeqNo::from(6);
|
||||
let s2 = SeqNo::from(60_000);
|
||||
assert!(!s1.lt(&s2));
|
||||
assert!(s2.lt(&s1));
|
||||
|
||||
assert!(!s2.gt(&s1));
|
||||
assert!(s1.gt(&s2));
|
||||
}
|
||||
}
|
||||
489
mycelium/src/source_table.rs
Normal file
489
mycelium/src/source_table.rs
Normal file
@@ -0,0 +1,489 @@
|
||||
use core::fmt;
|
||||
use std::{collections::HashMap, time::Duration};
|
||||
|
||||
use tokio::{sync::mpsc, task::JoinHandle};
|
||||
use tracing::error;
|
||||
|
||||
use crate::{
|
||||
babel, metric::Metric, router_id::RouterId, routing_table::RouteEntry, sequence_number::SeqNo,
|
||||
subnet::Subnet,
|
||||
};
|
||||
|
||||
/// Duration after which a source entry is deleted if it is not updated.
|
||||
const SOURCE_HOLD_DURATION: Duration = Duration::from_secs(60 * 30);
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Copy)]
|
||||
pub struct SourceKey {
|
||||
subnet: Subnet,
|
||||
router_id: RouterId,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct FeasibilityDistance {
|
||||
metric: Metric,
|
||||
seqno: SeqNo,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct SourceTable {
|
||||
table: HashMap<SourceKey, (JoinHandle<()>, FeasibilityDistance)>,
|
||||
}
|
||||
|
||||
impl FeasibilityDistance {
|
||||
pub fn new(metric: Metric, seqno: SeqNo) -> Self {
|
||||
FeasibilityDistance { metric, seqno }
|
||||
}
|
||||
|
||||
/// Returns the metric for this `FeasibilityDistance`.
|
||||
pub const fn metric(&self) -> Metric {
|
||||
self.metric
|
||||
}
|
||||
|
||||
/// Returns the sequence number for this `FeasibilityDistance`.
|
||||
pub const fn seqno(&self) -> SeqNo {
|
||||
self.seqno
|
||||
}
|
||||
}
|
||||
|
||||
impl SourceKey {
|
||||
/// Create a new `SourceKey`.
|
||||
pub const fn new(subnet: Subnet, router_id: RouterId) -> Self {
|
||||
Self { subnet, router_id }
|
||||
}
|
||||
|
||||
/// Returns the [`RouterId`] for this `SourceKey`.
|
||||
pub const fn router_id(&self) -> RouterId {
|
||||
self.router_id
|
||||
}
|
||||
|
||||
/// Returns the [`Subnet`] for this `SourceKey`.
|
||||
pub const fn subnet(&self) -> Subnet {
|
||||
self.subnet
|
||||
}
|
||||
|
||||
/// Updates the [`RouterId`] of this `SourceKey`
|
||||
pub fn set_router_id(&mut self, router_id: RouterId) {
|
||||
self.router_id = router_id
|
||||
}
|
||||
}
|
||||
|
||||
impl SourceTable {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
table: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn insert(
|
||||
&mut self,
|
||||
key: SourceKey,
|
||||
feas_dist: FeasibilityDistance,
|
||||
sink: mpsc::Sender<SourceKey>,
|
||||
) {
|
||||
let expiration_handle = tokio::spawn(async move {
|
||||
tokio::time::sleep(SOURCE_HOLD_DURATION).await;
|
||||
|
||||
if let Err(e) = sink.send(key).await {
|
||||
error!("Failed to notify router of expired source key {e}");
|
||||
}
|
||||
});
|
||||
// Abort the old task if present.
|
||||
if let Some((old_timeout, _)) = self.table.insert(key, (expiration_handle, feas_dist)) {
|
||||
old_timeout.abort();
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove an entry from the source table.
|
||||
pub fn remove(&mut self, key: &SourceKey) {
|
||||
if let Some((old_timeout, _)) = self.table.remove(key) {
|
||||
old_timeout.abort();
|
||||
};
|
||||
}
|
||||
|
||||
/// Resets the garbage collection timer for a given source key.
|
||||
///
|
||||
/// Does nothing if the source key is not present.
|
||||
pub fn reset_timer(&mut self, key: SourceKey, sink: mpsc::Sender<SourceKey>) {
|
||||
self.table
|
||||
.entry(key)
|
||||
.and_modify(|(old_expiration_handle, _)| {
|
||||
// First cancel the existing task
|
||||
old_expiration_handle.abort();
|
||||
// Then set the new one
|
||||
*old_expiration_handle = tokio::spawn(async move {
|
||||
tokio::time::sleep(SOURCE_HOLD_DURATION).await;
|
||||
|
||||
if let Err(e) = sink.send(key).await {
|
||||
error!("Failed to notify router of expired source key {e}");
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
/// Get the [`FeasibilityDistance`] currently associated with the [`SourceKey`].
|
||||
pub fn get(&self, key: &SourceKey) -> Option<&FeasibilityDistance> {
|
||||
self.table.get(key).map(|(_, v)| v)
|
||||
}
|
||||
|
||||
/// Indicates if an update is feasible in the context of the current `SoureTable`.
|
||||
pub fn is_update_feasible(&self, update: &babel::Update) -> bool {
|
||||
// Before an update is accepted it should be checked against the feasbility condition
|
||||
// If an entry in the source table with the same source key exists, we perform the feasbility check
|
||||
// If no entry exists yet, the update is accepted as there is no better alternative available (yet)
|
||||
let source_key = SourceKey::new(update.subnet(), update.router_id());
|
||||
match self.get(&source_key) {
|
||||
Some(entry) => {
|
||||
(update.seqno().gt(&entry.seqno()))
|
||||
|| (update.seqno() == entry.seqno() && update.metric() < entry.metric())
|
||||
|| update.metric().is_infinite()
|
||||
}
|
||||
None => true,
|
||||
}
|
||||
}
|
||||
|
||||
/// Indicates if a [`RouteEntry`] is feasible according to the `SourceTable`.
|
||||
pub fn route_feasible(&self, route: &RouteEntry) -> bool {
|
||||
match self.get(&route.source()) {
|
||||
Some(fd) => {
|
||||
(route.seqno().gt(&fd.seqno))
|
||||
|| (route.seqno() == fd.seqno && route.metric() < fd.metric)
|
||||
|| route.metric().is_infinite()
|
||||
}
|
||||
None => true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for SourceKey {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.write_fmt(format_args!(
|
||||
"{} advertised by {}",
|
||||
self.subnet, self.router_id
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use crate::{
|
||||
babel,
|
||||
crypto::SecretKey,
|
||||
metric::Metric,
|
||||
peer::Peer,
|
||||
router_id::RouterId,
|
||||
routing_table::RouteEntry,
|
||||
sequence_number::SeqNo,
|
||||
source_table::{FeasibilityDistance, SourceKey, SourceTable},
|
||||
subnet::Subnet,
|
||||
};
|
||||
use std::{
|
||||
net::Ipv6Addr,
|
||||
sync::{atomic::AtomicU64, Arc},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
/// A retraction is always considered to be feasible.
|
||||
#[tokio::test]
|
||||
async fn retraction_update_is_feasible() {
|
||||
let (sink, _) = tokio::sync::mpsc::channel(1);
|
||||
let sk = SecretKey::new();
|
||||
let pk = (&sk).into();
|
||||
let sn = Subnet::new(Ipv6Addr::new(0x400, 0, 0, 0, 0, 0, 0, 1).into(), 64)
|
||||
.expect("Valid subnet in test case");
|
||||
let rid = RouterId::new(pk);
|
||||
|
||||
let mut st = SourceTable::new();
|
||||
st.insert(
|
||||
SourceKey::new(sn, rid),
|
||||
FeasibilityDistance::new(Metric::new(10), SeqNo::from(1)),
|
||||
sink,
|
||||
);
|
||||
|
||||
let update = babel::Update::new(
|
||||
Duration::from_secs(60),
|
||||
SeqNo::from(0),
|
||||
Metric::infinite(),
|
||||
sn,
|
||||
rid,
|
||||
);
|
||||
|
||||
assert!(st.is_update_feasible(&update));
|
||||
}
|
||||
|
||||
/// An update with a smaller metric but with the same seqno is feasible.
|
||||
#[tokio::test]
|
||||
async fn smaller_metric_update_is_feasible() {
|
||||
let (sink, _) = tokio::sync::mpsc::channel(1);
|
||||
let sk = SecretKey::new();
|
||||
let pk = (&sk).into();
|
||||
let sn = Subnet::new(Ipv6Addr::new(0x400, 0, 0, 0, 0, 0, 0, 1).into(), 64)
|
||||
.expect("Valid subnet in test case");
|
||||
let rid = RouterId::new(pk);
|
||||
|
||||
let mut st = SourceTable::new();
|
||||
st.insert(
|
||||
SourceKey::new(sn, rid),
|
||||
FeasibilityDistance::new(Metric::new(10), SeqNo::from(1)),
|
||||
sink,
|
||||
);
|
||||
|
||||
let update = babel::Update::new(
|
||||
Duration::from_secs(60),
|
||||
SeqNo::from(1),
|
||||
Metric::from(9),
|
||||
sn,
|
||||
rid,
|
||||
);
|
||||
|
||||
assert!(st.is_update_feasible(&update));
|
||||
}
|
||||
|
||||
/// An update with the same metric and seqno is not feasible.
|
||||
#[tokio::test]
|
||||
async fn equal_metric_update_is_unfeasible() {
|
||||
let (sink, _) = tokio::sync::mpsc::channel(1);
|
||||
let sk = SecretKey::new();
|
||||
let pk = (&sk).into();
|
||||
let sn = Subnet::new(Ipv6Addr::new(0x400, 0, 0, 0, 0, 0, 0, 1).into(), 64)
|
||||
.expect("Valid subnet in test case");
|
||||
let rid = RouterId::new(pk);
|
||||
|
||||
let mut st = SourceTable::new();
|
||||
st.insert(
|
||||
SourceKey::new(sn, rid),
|
||||
FeasibilityDistance::new(Metric::new(10), SeqNo::from(1)),
|
||||
sink,
|
||||
);
|
||||
|
||||
let update = babel::Update::new(
|
||||
Duration::from_secs(60),
|
||||
SeqNo::from(1),
|
||||
Metric::from(10),
|
||||
sn,
|
||||
rid,
|
||||
);
|
||||
|
||||
assert!(!st.is_update_feasible(&update));
|
||||
}
|
||||
|
||||
/// An update with a larger metric and the same seqno is not feasible.
|
||||
#[tokio::test]
|
||||
async fn larger_metric_update_is_unfeasible() {
|
||||
let (sink, _) = tokio::sync::mpsc::channel(1);
|
||||
let sk = SecretKey::new();
|
||||
let pk = (&sk).into();
|
||||
let sn = Subnet::new(Ipv6Addr::new(0x400, 0, 0, 0, 0, 0, 0, 1).into(), 64)
|
||||
.expect("Valid subnet in test case");
|
||||
let rid = RouterId::new(pk);
|
||||
|
||||
let mut st = SourceTable::new();
|
||||
st.insert(
|
||||
SourceKey::new(sn, rid),
|
||||
FeasibilityDistance::new(Metric::new(10), SeqNo::from(1)),
|
||||
sink,
|
||||
);
|
||||
|
||||
let update = babel::Update::new(
|
||||
Duration::from_secs(60),
|
||||
SeqNo::from(1),
|
||||
Metric::from(11),
|
||||
sn,
|
||||
rid,
|
||||
);
|
||||
|
||||
assert!(!st.is_update_feasible(&update));
|
||||
}
|
||||
|
||||
/// An update with a lower seqno is not feasible.
|
||||
#[tokio::test]
|
||||
async fn lower_seqno_update_is_unfeasible() {
|
||||
let (sink, _) = tokio::sync::mpsc::channel(1);
|
||||
let sk = SecretKey::new();
|
||||
let pk = (&sk).into();
|
||||
let sn = Subnet::new(Ipv6Addr::new(0x400, 0, 0, 0, 0, 0, 0, 1).into(), 64)
|
||||
.expect("Valid subnet in test case");
|
||||
let rid = RouterId::new(pk);
|
||||
|
||||
let mut st = SourceTable::new();
|
||||
st.insert(
|
||||
SourceKey::new(sn, rid),
|
||||
FeasibilityDistance::new(Metric::new(10), SeqNo::from(1)),
|
||||
sink,
|
||||
);
|
||||
|
||||
let update = babel::Update::new(
|
||||
Duration::from_secs(60),
|
||||
SeqNo::from(0),
|
||||
Metric::from(1),
|
||||
sn,
|
||||
rid,
|
||||
);
|
||||
|
||||
assert!(!st.is_update_feasible(&update));
|
||||
}
|
||||
|
||||
/// An update with a higher seqno is feasible.
|
||||
#[tokio::test]
|
||||
async fn higher_seqno_update_is_feasible() {
|
||||
let (sink, _) = tokio::sync::mpsc::channel(1);
|
||||
let sk = SecretKey::new();
|
||||
let pk = (&sk).into();
|
||||
let sn = Subnet::new(Ipv6Addr::new(0x400, 0, 0, 0, 0, 0, 0, 1).into(), 64)
|
||||
.expect("Valid subnet in test case");
|
||||
let rid = RouterId::new(pk);
|
||||
|
||||
let mut st = SourceTable::new();
|
||||
st.insert(
|
||||
SourceKey::new(sn, rid),
|
||||
FeasibilityDistance::new(Metric::new(10), SeqNo::from(1)),
|
||||
sink,
|
||||
);
|
||||
|
||||
let update = babel::Update::new(
|
||||
Duration::from_secs(60),
|
||||
SeqNo::from(2),
|
||||
Metric::from(200),
|
||||
sn,
|
||||
rid,
|
||||
);
|
||||
|
||||
assert!(st.is_update_feasible(&update));
|
||||
}
|
||||
|
||||
/// A route with a smaller metric but with the same seqno is feasible.
|
||||
#[tokio::test]
|
||||
async fn smaller_metric_route_is_feasible() {
|
||||
let (sink, _) = tokio::sync::mpsc::channel(1);
|
||||
let sk = SecretKey::new();
|
||||
let pk = (&sk).into();
|
||||
let sn = Subnet::new(Ipv6Addr::new(0x400, 0, 0, 0, 0, 0, 0, 1).into(), 64)
|
||||
.expect("Valid subnet in test case");
|
||||
let rid = RouterId::new(pk);
|
||||
|
||||
let source_key = SourceKey::new(sn, rid);
|
||||
|
||||
let mut st = SourceTable::new();
|
||||
st.insert(
|
||||
source_key,
|
||||
FeasibilityDistance::new(Metric::new(10), SeqNo::from(1)),
|
||||
sink,
|
||||
);
|
||||
|
||||
let (router_data_tx, _router_data_rx) = mpsc::channel(1);
|
||||
let (router_control_tx, _router_control_rx) = mpsc::unbounded_channel();
|
||||
let (dead_peer_sink, _dead_peer_stream) = mpsc::channel(1);
|
||||
let (con1, _con2) = tokio::io::duplex(1500);
|
||||
let neighbor = Peer::new(
|
||||
router_data_tx,
|
||||
router_control_tx,
|
||||
con1,
|
||||
dead_peer_sink,
|
||||
Arc::new(AtomicU64::new(0)),
|
||||
Arc::new(AtomicU64::new(0)),
|
||||
)
|
||||
.expect("Can create a dummy peer");
|
||||
|
||||
let re = RouteEntry::new(
|
||||
source_key,
|
||||
neighbor,
|
||||
Metric::new(9),
|
||||
SeqNo::from(1),
|
||||
true,
|
||||
tokio::time::Instant::now() + Duration::from_secs(60),
|
||||
);
|
||||
|
||||
assert!(st.route_feasible(&re));
|
||||
}
|
||||
|
||||
/// If a route has the same metric as the source table it is not feasible.
|
||||
#[tokio::test]
|
||||
async fn equal_metric_route_is_unfeasible() {
|
||||
let (sink, _) = tokio::sync::mpsc::channel(1);
|
||||
let sk = SecretKey::new();
|
||||
let pk = (&sk).into();
|
||||
let sn = Subnet::new(Ipv6Addr::new(0x400, 0, 0, 0, 0, 0, 0, 1).into(), 64)
|
||||
.expect("Valid subnet in test case");
|
||||
let rid = RouterId::new(pk);
|
||||
|
||||
let source_key = SourceKey::new(sn, rid);
|
||||
|
||||
let mut st = SourceTable::new();
|
||||
st.insert(
|
||||
source_key,
|
||||
FeasibilityDistance::new(Metric::new(10), SeqNo::from(1)),
|
||||
sink,
|
||||
);
|
||||
|
||||
let (router_data_tx, _router_data_rx) = mpsc::channel(1);
|
||||
let (router_control_tx, _router_control_rx) = mpsc::unbounded_channel();
|
||||
let (dead_peer_sink, _dead_peer_stream) = mpsc::channel(1);
|
||||
let (con1, _con2) = tokio::io::duplex(1500);
|
||||
let neighbor = Peer::new(
|
||||
router_data_tx,
|
||||
router_control_tx,
|
||||
con1,
|
||||
dead_peer_sink,
|
||||
Arc::new(AtomicU64::new(0)),
|
||||
Arc::new(AtomicU64::new(0)),
|
||||
)
|
||||
.expect("Can create a dummy peer");
|
||||
|
||||
let re = RouteEntry::new(
|
||||
source_key,
|
||||
neighbor,
|
||||
Metric::new(10),
|
||||
SeqNo::from(1),
|
||||
true,
|
||||
tokio::time::Instant::now() + Duration::from_secs(60),
|
||||
);
|
||||
|
||||
assert!(!st.route_feasible(&re));
|
||||
}
|
||||
|
||||
/// If a route has a higher metric as the source table it is not feasible.
|
||||
#[tokio::test]
|
||||
async fn higher_metric_route_is_unfeasible() {
|
||||
let (sink, _) = tokio::sync::mpsc::channel(1);
|
||||
let sk = SecretKey::new();
|
||||
let pk = (&sk).into();
|
||||
let sn = Subnet::new(Ipv6Addr::new(0x400, 0, 0, 0, 0, 0, 0, 1).into(), 64)
|
||||
.expect("Valid subnet in test case");
|
||||
let rid = RouterId::new(pk);
|
||||
|
||||
let source_key = SourceKey::new(sn, rid);
|
||||
|
||||
let mut st = SourceTable::new();
|
||||
st.insert(
|
||||
source_key,
|
||||
FeasibilityDistance::new(Metric::new(10), SeqNo::from(1)),
|
||||
sink,
|
||||
);
|
||||
|
||||
let (router_data_tx, _router_data_rx) = mpsc::channel(1);
|
||||
let (router_control_tx, _router_control_rx) = mpsc::unbounded_channel();
|
||||
let (dead_peer_sink, _dead_peer_stream) = mpsc::channel(1);
|
||||
let (con1, _con2) = tokio::io::duplex(1500);
|
||||
let neighbor = Peer::new(
|
||||
router_data_tx,
|
||||
router_control_tx,
|
||||
con1,
|
||||
dead_peer_sink,
|
||||
Arc::new(AtomicU64::new(0)),
|
||||
Arc::new(AtomicU64::new(0)),
|
||||
)
|
||||
.expect("Can create a dummy peer");
|
||||
|
||||
let re = RouteEntry::new(
|
||||
source_key,
|
||||
neighbor,
|
||||
Metric::new(11),
|
||||
SeqNo::from(1),
|
||||
true,
|
||||
tokio::time::Instant::now() + Duration::from_secs(60),
|
||||
);
|
||||
|
||||
assert!(!st.route_feasible(&re));
|
||||
}
|
||||
}
|
||||
277
mycelium/src/subnet.rs
Normal file
277
mycelium/src/subnet.rs
Normal file
@@ -0,0 +1,277 @@
|
||||
//! A dedicated subnet module.
|
||||
//!
|
||||
//! The standard library only exposes [`IpAddr`], and types related to
|
||||
//! specific IPv4 and IPv6 addresses. It does not however, expose dedicated types to represent
|
||||
//! appropriate subnets.
|
||||
//!
|
||||
//! This code is not meant to fully support subnets, but rather only the subset as needed by the
|
||||
//! main application code. As such, this implementation is optimized for the specific use case, and
|
||||
//! might not be optimal for other uses.
|
||||
|
||||
use core::fmt;
|
||||
use std::{
|
||||
hash::Hash,
|
||||
net::{IpAddr, Ipv6Addr},
|
||||
str::FromStr,
|
||||
};
|
||||
|
||||
use ipnet::IpNet;
|
||||
|
||||
/// Representation of a subnet. A subnet can be either IPv4 or IPv6.
|
||||
#[derive(Debug, Clone, Copy, Eq, PartialOrd, Ord)]
|
||||
pub struct Subnet {
|
||||
inner: IpNet,
|
||||
}
|
||||
|
||||
/// An error returned when creating a new [`Subnet`] with an invalid prefix length.
|
||||
///
|
||||
/// For IPv4, the max prefix length is 32, and for IPv6 it is 128;
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub struct PrefixLenError;
|
||||
|
||||
impl Subnet {
|
||||
/// Create a new `Subnet` from the given [`IpAddr`] and prefix length.
|
||||
pub fn new(addr: IpAddr, prefix_len: u8) -> Result<Subnet, PrefixLenError> {
|
||||
Ok(Self {
|
||||
inner: IpNet::new(addr, prefix_len).map_err(|_| PrefixLenError)?,
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns the size of the prefix in bits.
|
||||
pub fn prefix_len(&self) -> u8 {
|
||||
self.inner.prefix_len()
|
||||
}
|
||||
|
||||
/// Retuns the address in this subnet.
|
||||
///
|
||||
/// The returned address is a full IP address, used to construct this `Subnet`.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use mycelium::subnet::Subnet;
|
||||
/// use std::net::Ipv6Addr;
|
||||
///
|
||||
/// let address = Ipv6Addr::new(12,34,56,78,90,0xab,0xcd,0xef).into();
|
||||
/// let subnet = Subnet::new(address, 64).unwrap();
|
||||
///
|
||||
/// assert_eq!(subnet.address(), address);
|
||||
/// ```
|
||||
pub fn address(&self) -> IpAddr {
|
||||
self.inner.addr()
|
||||
}
|
||||
|
||||
/// Checks if this `Subnet` contains the provided `Subnet`, i.e. all addresses of the provided
|
||||
/// `Subnet` are also part of this `Subnet`
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use mycelium::subnet::Subnet;
|
||||
/// use std::net::Ipv4Addr;
|
||||
///
|
||||
/// let global = Subnet::new(Ipv4Addr::new(0,0,0,0).into(), 0).expect("Defined a valid subnet");
|
||||
/// let local = Subnet::new(Ipv4Addr::new(10,0,0,0).into(), 8).expect("Defined a valid subnet");
|
||||
///
|
||||
/// assert!(global.contains_subnet(&local));
|
||||
/// assert!(!local.contains_subnet(&global));
|
||||
/// ```
|
||||
pub fn contains_subnet(&self, other: &Self) -> bool {
|
||||
self.inner.contains(&other.inner)
|
||||
}
|
||||
|
||||
/// Checks if this `Subnet` contains the provided [`IpAddr`].
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use mycelium::subnet::Subnet;
|
||||
/// use std::net::{Ipv4Addr,Ipv6Addr};
|
||||
///
|
||||
/// let ip_1 = Ipv6Addr::new(12,34,56,78,90,0xab,0xcd,0xef).into();
|
||||
/// let ip_2 = Ipv6Addr::new(90,0xab,0xcd,0xef,12,34,56,78).into();
|
||||
/// let ip_3 = Ipv4Addr::new(10,1,2,3).into();
|
||||
/// let subnet = Subnet::new(Ipv6Addr::new(12,34,5,6,7,8,9,0).into(), 32).unwrap();
|
||||
///
|
||||
/// assert!(subnet.contains_ip(ip_1));
|
||||
/// assert!(!subnet.contains_ip(ip_2));
|
||||
/// assert!(!subnet.contains_ip(ip_3));
|
||||
/// ```
|
||||
pub fn contains_ip(&self, ip: IpAddr) -> bool {
|
||||
self.inner.contains(&ip)
|
||||
}
|
||||
|
||||
/// Returns the network part of the `Subnet`. All non prefix bits are set to 0.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use mycelium::subnet::Subnet;
|
||||
/// use std::net::{IpAddr, Ipv4Addr,Ipv6Addr};
|
||||
///
|
||||
/// let subnet_1 = Subnet::new(Ipv6Addr::new(12,34,56,78,90,0xab,0xcd,0xef).into(),
|
||||
/// 32).unwrap();
|
||||
/// let subnet_2 = Subnet::new(Ipv4Addr::new(10,1,2,3).into(), 8).unwrap();
|
||||
///
|
||||
/// assert_eq!(subnet_1.network(), IpAddr::V6(Ipv6Addr::new(12,34,0,0,0,0,0,0)));
|
||||
/// assert_eq!(subnet_2.network(), IpAddr::V4(Ipv4Addr::new(10,0,0,0)));
|
||||
/// ```
|
||||
pub fn network(&self) -> IpAddr {
|
||||
self.inner.network()
|
||||
}
|
||||
|
||||
/// Returns the braodcast address for the subnet.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use mycelium::subnet::Subnet;
|
||||
/// use std::net::{IpAddr, Ipv4Addr,Ipv6Addr};
|
||||
///
|
||||
/// let subnet_1 = Subnet::new(Ipv6Addr::new(12,34,56,78,90,0xab,0xcd,0xef).into(),
|
||||
/// 32).unwrap();
|
||||
/// let subnet_2 = Subnet::new(Ipv4Addr::new(10,1,2,3).into(), 8).unwrap();
|
||||
///
|
||||
/// assert_eq!(subnet_1.broadcast_addr(),
|
||||
/// IpAddr::V6(Ipv6Addr::new(12,34,0xffff,0xffff,0xffff,0xffff,0xffff,0xffff)));
|
||||
/// assert_eq!(subnet_2.broadcast_addr(), IpAddr::V4(Ipv4Addr::new(10,255,255,255)));
|
||||
/// ```
|
||||
pub fn broadcast_addr(&self) -> IpAddr {
|
||||
self.inner.broadcast()
|
||||
}
|
||||
|
||||
/// Returns the netmask of the subnet as an [`IpAddr`].
|
||||
pub fn mask(&self) -> IpAddr {
|
||||
self.inner.netmask()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Ipv6Addr> for Subnet {
|
||||
fn from(value: Ipv6Addr) -> Self {
|
||||
Self::new(value.into(), 128).expect("128 is a valid subnet size for an IPv6 address; qed")
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
/// An error indicating a malformed subnet
|
||||
pub struct SubnetParseError {
|
||||
_private: (),
|
||||
}
|
||||
|
||||
impl SubnetParseError {
|
||||
/// Create a new SubnetParseError
|
||||
fn new() -> Self {
|
||||
Self { _private: () }
|
||||
}
|
||||
}
|
||||
|
||||
impl core::fmt::Display for SubnetParseError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.pad("malformed subnet")
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for SubnetParseError {}
|
||||
|
||||
impl FromStr for Subnet {
|
||||
type Err = SubnetParseError;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
if let Ok(ipnet) = s.parse::<ipnet::IpNet>() {
|
||||
return Ok(
|
||||
Subnet::new(ipnet.addr(), ipnet.prefix_len()).expect("Parsed subnet size is valid")
|
||||
);
|
||||
}
|
||||
|
||||
// Try to parse as an IP address (convert to /32 or /128 subnet)
|
||||
if let Ok(ip) = s.parse::<std::net::IpAddr>() {
|
||||
let prefix_len = match ip {
|
||||
std::net::IpAddr::V4(_) => 32,
|
||||
std::net::IpAddr::V6(_) => 128,
|
||||
};
|
||||
return Ok(Subnet::new(ip, prefix_len).expect("Static subnet sizes are valid"));
|
||||
}
|
||||
|
||||
Err(SubnetParseError::new())
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for Subnet {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{}", self.inner)
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq for Subnet {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
// Quic check, subnets of different sizes are never equal.
|
||||
if self.prefix_len() != other.prefix_len() {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Full check
|
||||
self.network() == other.network()
|
||||
}
|
||||
}
|
||||
|
||||
impl Hash for Subnet {
|
||||
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
|
||||
// First write the subnet size
|
||||
state.write_u8(self.prefix_len());
|
||||
// Then write the IP of the network. This sets the non prefix bits to 0, so hash values
|
||||
// will be equal according to the PartialEq rules.
|
||||
self.network().hash(state)
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for PrefixLenError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.write_str("Invalid prefix length for this address")
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for PrefixLenError {}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::net::{Ipv4Addr, Ipv6Addr};
|
||||
|
||||
use super::Subnet;
|
||||
|
||||
#[test]
|
||||
fn test_subnet_equality() {
|
||||
let subnet_1 =
|
||||
Subnet::new(Ipv6Addr::new(12, 23, 34, 45, 56, 67, 78, 89).into(), 64).unwrap();
|
||||
let subnet_2 =
|
||||
Subnet::new(Ipv6Addr::new(12, 23, 34, 45, 67, 78, 89, 90).into(), 64).unwrap();
|
||||
let subnet_3 =
|
||||
Subnet::new(Ipv6Addr::new(12, 23, 34, 40, 67, 78, 89, 90).into(), 64).unwrap();
|
||||
let subnet_4 = Subnet::new(Ipv6Addr::new(12, 23, 34, 45, 0, 0, 0, 0).into(), 64).unwrap();
|
||||
let subnet_5 = Subnet::new(
|
||||
Ipv6Addr::new(12, 23, 34, 45, 0xffff, 0xffff, 0xffff, 0xffff).into(),
|
||||
64,
|
||||
)
|
||||
.unwrap();
|
||||
let subnet_6 =
|
||||
Subnet::new(Ipv6Addr::new(12, 23, 34, 45, 56, 67, 78, 89).into(), 63).unwrap();
|
||||
|
||||
assert_eq!(subnet_1, subnet_2);
|
||||
assert_ne!(subnet_1, subnet_3);
|
||||
assert_eq!(subnet_1, subnet_4);
|
||||
assert_eq!(subnet_1, subnet_5);
|
||||
assert_ne!(subnet_1, subnet_6);
|
||||
|
||||
let subnet_1 = Subnet::new(Ipv4Addr::new(10, 1, 2, 3).into(), 24).unwrap();
|
||||
let subnet_2 = Subnet::new(Ipv4Addr::new(10, 1, 2, 102).into(), 24).unwrap();
|
||||
let subnet_3 = Subnet::new(Ipv4Addr::new(10, 1, 4, 3).into(), 24).unwrap();
|
||||
let subnet_4 = Subnet::new(Ipv4Addr::new(10, 1, 2, 0).into(), 24).unwrap();
|
||||
let subnet_5 = Subnet::new(Ipv4Addr::new(10, 1, 2, 255).into(), 24).unwrap();
|
||||
let subnet_6 = Subnet::new(Ipv4Addr::new(10, 1, 2, 3).into(), 16).unwrap();
|
||||
|
||||
assert_eq!(subnet_1, subnet_2);
|
||||
assert_ne!(subnet_1, subnet_3);
|
||||
assert_eq!(subnet_1, subnet_4);
|
||||
assert_eq!(subnet_1, subnet_5);
|
||||
assert_ne!(subnet_1, subnet_6);
|
||||
}
|
||||
}
|
||||
29
mycelium/src/task.rs
Normal file
29
mycelium/src/task.rs
Normal file
@@ -0,0 +1,29 @@
|
||||
//! This module provides some task abstractions which add custom logic to the default behavior.
|
||||
|
||||
/// A handle to a task, which is only used to abort the task. In case this handle is dropped, the
|
||||
/// task is cancelled automatically.
|
||||
pub struct AbortHandle(tokio::task::AbortHandle);
|
||||
|
||||
impl AbortHandle {
|
||||
/// Abort the task this `AbortHandle` is referencing. It is safe to call this method multiple
|
||||
/// times, but only the first call is actually usefull. It is possible for the task to still
|
||||
/// finish succesfully, even after abort is called.
|
||||
#[inline]
|
||||
pub fn abort(&self) {
|
||||
self.0.abort()
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for AbortHandle {
|
||||
#[inline]
|
||||
fn drop(&mut self) {
|
||||
self.0.abort()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<tokio::task::AbortHandle> for AbortHandle {
|
||||
#[inline]
|
||||
fn from(value: tokio::task::AbortHandle) -> Self {
|
||||
Self(value)
|
||||
}
|
||||
}
|
||||
54
mycelium/src/tun.rs
Normal file
54
mycelium/src/tun.rs
Normal file
@@ -0,0 +1,54 @@
|
||||
//! The tun module implements a platform independent Tun interface.
|
||||
|
||||
#[cfg(any(
|
||||
target_os = "linux",
|
||||
all(target_os = "macos", not(feature = "mactunfd")),
|
||||
target_os = "windows"
|
||||
))]
|
||||
use crate::subnet::Subnet;
|
||||
|
||||
#[cfg(any(
|
||||
target_os = "linux",
|
||||
all(target_os = "macos", not(feature = "mactunfd")),
|
||||
target_os = "windows"
|
||||
))]
|
||||
pub struct TunConfig {
|
||||
pub name: String,
|
||||
pub node_subnet: Subnet,
|
||||
pub route_subnet: Subnet,
|
||||
}
|
||||
|
||||
#[cfg(any(
|
||||
target_os = "android",
|
||||
target_os = "ios",
|
||||
all(target_os = "macos", feature = "mactunfd"),
|
||||
))]
|
||||
pub struct TunConfig {
|
||||
pub tun_fd: i32,
|
||||
}
|
||||
#[cfg(target_os = "linux")]
|
||||
mod linux;
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
pub use linux::new;
|
||||
|
||||
#[cfg(all(target_os = "macos", not(feature = "mactunfd")))]
|
||||
mod darwin;
|
||||
|
||||
#[cfg(all(target_os = "macos", not(feature = "mactunfd")))]
|
||||
pub use darwin::new;
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
mod windows;
|
||||
#[cfg(target_os = "windows")]
|
||||
pub use windows::new;
|
||||
|
||||
#[cfg(target_os = "android")]
|
||||
mod android;
|
||||
#[cfg(target_os = "android")]
|
||||
pub use android::new;
|
||||
|
||||
#[cfg(any(target_os = "ios", all(target_os = "macos", feature = "mactunfd")))]
|
||||
mod ios;
|
||||
#[cfg(any(target_os = "ios", all(target_os = "macos", feature = "mactunfd")))]
|
||||
pub use ios::new;
|
||||
108
mycelium/src/tun/android.rs
Normal file
108
mycelium/src/tun/android.rs
Normal file
@@ -0,0 +1,108 @@
|
||||
//! android specific tun interface setup.
|
||||
|
||||
use std::io::{self};
|
||||
|
||||
use futures::{Sink, Stream};
|
||||
use tokio::{
|
||||
io::{AsyncReadExt, AsyncWriteExt},
|
||||
select,
|
||||
sync::mpsc,
|
||||
};
|
||||
use tracing::{error, info};
|
||||
|
||||
use crate::crypto::PacketBuffer;
|
||||
use crate::tun::TunConfig;
|
||||
|
||||
// TODO
|
||||
const LINK_MTU: i32 = 1400;
|
||||
|
||||
/// Create a new tun interface and set required routes
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// This function will panic if called outside of the context of a tokio runtime.
|
||||
pub async fn new(
|
||||
tun_config: TunConfig,
|
||||
) -> Result<
|
||||
(
|
||||
impl Stream<Item = io::Result<PacketBuffer>>,
|
||||
impl Sink<PacketBuffer, Error = impl std::error::Error> + Clone,
|
||||
),
|
||||
Box<dyn std::error::Error>,
|
||||
> {
|
||||
let name = "tun0";
|
||||
let mut tun = create_tun_interface(name, tun_config.tun_fd)?;
|
||||
|
||||
let (tun_sink, mut sink_receiver) = mpsc::channel::<PacketBuffer>(1000);
|
||||
let (tun_stream, stream_receiver) = mpsc::unbounded_channel();
|
||||
|
||||
// Spawn a single task to manage the TUN interface
|
||||
tokio::spawn(async move {
|
||||
let mut buf_hold = None;
|
||||
loop {
|
||||
let mut buf = if let Some(buf) = buf_hold.take() {
|
||||
buf
|
||||
} else {
|
||||
PacketBuffer::new()
|
||||
};
|
||||
|
||||
select! {
|
||||
data = sink_receiver.recv() => {
|
||||
match data {
|
||||
None => return,
|
||||
Some(data) => {
|
||||
if let Err(e) = tun.write(&data).await {
|
||||
error!("Failed to send data to tun interface {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
// Save the buffer as we didn't use it
|
||||
buf_hold = Some(buf);
|
||||
}
|
||||
read_result = tun.read(buf.buffer_mut()) => {
|
||||
let rr = read_result.map(|n| {
|
||||
buf.set_size(n);
|
||||
buf
|
||||
});
|
||||
|
||||
|
||||
if tun_stream.send(rr).is_err() {
|
||||
error!("Could not forward data to tun stream, receiver is gone");
|
||||
break;
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
info!("Stop reading from / writing to tun interface");
|
||||
});
|
||||
|
||||
Ok((
|
||||
tokio_stream::wrappers::UnboundedReceiverStream::new(stream_receiver),
|
||||
tokio_util::sync::PollSender::new(tun_sink),
|
||||
))
|
||||
}
|
||||
|
||||
/// Create a new TUN interface
|
||||
fn create_tun_interface(
|
||||
name: &str,
|
||||
tun_fd: i32,
|
||||
) -> Result<tun::AsyncDevice, Box<dyn std::error::Error>> {
|
||||
let mut config = tun::Configuration::default();
|
||||
config
|
||||
.name(name)
|
||||
.layer(tun::Layer::L3)
|
||||
.mtu(LINK_MTU)
|
||||
.queues(1)
|
||||
.raw_fd(tun_fd)
|
||||
.up();
|
||||
info!("create_tun_interface");
|
||||
let tun = match tun::create_as_async(&config) {
|
||||
Ok(tun) => tun,
|
||||
Err(err) => {
|
||||
error!("[android]failed to create tun interface: {err}");
|
||||
return Err(Box::new(err));
|
||||
}
|
||||
};
|
||||
|
||||
Ok(tun)
|
||||
}
|
||||
347
mycelium/src/tun/darwin.rs
Normal file
347
mycelium/src/tun/darwin.rs
Normal file
@@ -0,0 +1,347 @@
|
||||
//! macos specific tun interface setup.
|
||||
|
||||
use std::{
|
||||
ffi::CString,
|
||||
io::{self, IoSlice},
|
||||
net::IpAddr,
|
||||
os::fd::AsRawFd,
|
||||
str::FromStr,
|
||||
};
|
||||
|
||||
use futures::{Sink, Stream};
|
||||
use nix::sys::socket::SockaddrIn6;
|
||||
use tokio::{
|
||||
io::{AsyncReadExt, AsyncWriteExt},
|
||||
select,
|
||||
sync::mpsc,
|
||||
};
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::crypto::PacketBuffer;
|
||||
use crate::subnet::Subnet;
|
||||
use crate::tun::TunConfig;
|
||||
|
||||
// TODO
|
||||
const LINK_MTU: i32 = 1400;
|
||||
|
||||
/// The 4 byte packet header written before a packet is sent on the TUN
|
||||
// TODO: figure out structure and values, but for now this seems to work.
|
||||
const HEADER: [u8; 4] = [0, 0, 0, 30];
|
||||
|
||||
const IN6_IFF_NODAD: u32 = 0x0020; // netinet6/in6_var.h
|
||||
const IN6_IFF_SECURED: u32 = 0x0400; // netinet6/in6_var.h
|
||||
const ND6_INFINITE_LIFETIME: u32 = 0xFFFFFFFF; // netinet6/nd6.h
|
||||
|
||||
/// Wrapper for an OS-specific interface name
|
||||
// Allways hold the max size of an interface. This includes the 0 byte for termination.
|
||||
// repr transparent so this can be used with libc calls.
|
||||
#[repr(transparent)]
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct IfaceName([libc::c_char; libc::IFNAMSIZ as _]);
|
||||
|
||||
/// Wrapped interface handle.
|
||||
#[derive(Clone, Copy)]
|
||||
struct Iface {
|
||||
/// Name of the interface
|
||||
iface_name: IfaceName,
|
||||
}
|
||||
|
||||
/// Struct to add IPv6 route to interface
|
||||
#[repr(C)]
|
||||
pub struct IfaliasReq {
|
||||
ifname: IfaceName,
|
||||
addr: SockaddrIn6,
|
||||
dst_addr: SockaddrIn6,
|
||||
mask: SockaddrIn6,
|
||||
flags: u32,
|
||||
lifetime: AddressLifetime,
|
||||
}
|
||||
|
||||
#[repr(C)]
|
||||
pub struct AddressLifetime {
|
||||
/// Not used for userspace -> kernel space
|
||||
expire: libc::time_t,
|
||||
/// Not used for userspace -> kernel space
|
||||
preferred: libc::time_t,
|
||||
vltime: u32,
|
||||
pltime: u32,
|
||||
}
|
||||
|
||||
/// Create a new tun interface and set required routes
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// This function will panic if called outside of the context of a tokio runtime.
|
||||
pub async fn new(
|
||||
tun_config: TunConfig,
|
||||
) -> Result<
|
||||
(
|
||||
impl Stream<Item = io::Result<PacketBuffer>>,
|
||||
impl Sink<PacketBuffer, Error = impl std::error::Error> + Clone,
|
||||
),
|
||||
Box<dyn std::error::Error>,
|
||||
> {
|
||||
let tun_name = find_available_utun_name(&tun_config.name)?;
|
||||
|
||||
let mut tun = match create_tun_interface(&tun_name) {
|
||||
Ok(tun) => tun,
|
||||
Err(e) => {
|
||||
error!(tun_name=%tun_name, err=%e, "Could not create TUN device. Make sure the name is not yet in use, and you have sufficient privileges to create a network device");
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
let iface = Iface::by_name(&tun_name)?;
|
||||
iface.add_address(tun_config.node_subnet, tun_config.route_subnet)?;
|
||||
|
||||
let (tun_sink, mut sink_receiver) = mpsc::channel::<PacketBuffer>(1000);
|
||||
let (tun_stream, stream_receiver) = mpsc::unbounded_channel();
|
||||
|
||||
// Spawn a single task to manage the TUN interface
|
||||
tokio::spawn(async move {
|
||||
let mut buf_hold = None;
|
||||
loop {
|
||||
let mut buf: PacketBuffer = buf_hold.take().unwrap_or_default();
|
||||
|
||||
select! {
|
||||
data = sink_receiver.recv() => {
|
||||
match data {
|
||||
None => return,
|
||||
Some(data) => {
|
||||
// We need to append a 4 byte header here
|
||||
if let Err(e) = tun.write_vectored(&[IoSlice::new(&HEADER), IoSlice::new(&data)]).await {
|
||||
error!("Failed to send data to tun interface {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
// Save the buffer as we didn't use it
|
||||
buf_hold = Some(buf);
|
||||
}
|
||||
read_result = tun.read(buf.buffer_mut()) => {
|
||||
let rr = read_result.map(|n| {
|
||||
buf.set_size(n);
|
||||
// Trim header
|
||||
buf.buffer_mut().copy_within(4.., 0);
|
||||
buf.set_size(n-4);
|
||||
buf
|
||||
});
|
||||
|
||||
|
||||
if tun_stream.send(rr).is_err() {
|
||||
error!("Could not forward data to tun stream, receiver is gone");
|
||||
break;
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
info!("Stop reading from / writing to tun interface");
|
||||
});
|
||||
|
||||
Ok((
|
||||
tokio_stream::wrappers::UnboundedReceiverStream::new(stream_receiver),
|
||||
tokio_util::sync::PollSender::new(tun_sink),
|
||||
))
|
||||
}
|
||||
|
||||
/// Checks if a name is valid for a utun interface
|
||||
///
|
||||
/// Rules:
|
||||
/// - must start with "utun"
|
||||
/// - followed by only digits
|
||||
/// - 15 chars total at most
|
||||
fn validate_utun_name(input: &str) -> bool {
|
||||
if input.len() > 15 {
|
||||
return false;
|
||||
}
|
||||
if !input.starts_with("utun") {
|
||||
return false;
|
||||
}
|
||||
|
||||
input
|
||||
.strip_prefix("utun")
|
||||
.expect("We just checked that name starts with 'utun' so this is always some")
|
||||
.parse::<u64>()
|
||||
.is_ok()
|
||||
}
|
||||
|
||||
/// Validates the user-supplied TUN interface name
|
||||
///
|
||||
/// - If the name is valid and not in use, it will be the TUN name
|
||||
/// - If the name is valid but already in use, an error will be thrown
|
||||
/// - If the name is not valid, we try to find the first freely available TUN name
|
||||
fn find_available_utun_name(preferred_name: &str) -> Result<String, io::Error> {
|
||||
// Get the list of existing utun interfaces.
|
||||
let interfaces = netdev::get_interfaces();
|
||||
let utun_interfaces: Vec<_> = interfaces
|
||||
.iter()
|
||||
.filter_map(|iface| {
|
||||
if iface.name.starts_with("utun") {
|
||||
Some(iface.name.as_str())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Check if the preferred name is valid and not in use.
|
||||
if validate_utun_name(preferred_name) && !utun_interfaces.contains(&preferred_name) {
|
||||
return Ok(preferred_name.to_string());
|
||||
}
|
||||
|
||||
// If the preferred name is invalid or already in use, find the first available utun name.
|
||||
if !validate_utun_name(preferred_name) {
|
||||
warn!(tun_name=%preferred_name, "Invalid TUN name. Looking for the first available TUN name");
|
||||
} else {
|
||||
warn!(tun_name=%preferred_name, "TUN name already in use. Looking for the next available TUN name.");
|
||||
}
|
||||
|
||||
// Extract and sort the utun numbers.
|
||||
let mut utun_numbers = utun_interfaces
|
||||
.iter()
|
||||
.filter_map(|iface| iface[4..].parse::<usize>().ok())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
utun_numbers.sort_unstable();
|
||||
|
||||
// Find the first available utun index.
|
||||
let mut first_free_index = 0;
|
||||
for (i, &num) in utun_numbers.iter().enumerate() {
|
||||
if num != i {
|
||||
first_free_index = i;
|
||||
break;
|
||||
}
|
||||
first_free_index = i + 1;
|
||||
}
|
||||
|
||||
// Create new utun name based on the first free index.
|
||||
let new_utun_name = format!("utun{}", first_free_index);
|
||||
if validate_utun_name(&new_utun_name) {
|
||||
info!(tun_name=%new_utun_name, "Automatically assigned TUN name.");
|
||||
Ok(new_utun_name)
|
||||
} else {
|
||||
error!("No available TUN name found");
|
||||
Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"No available TUN name",
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new TUN interface
|
||||
fn create_tun_interface(name: &str) -> Result<tun::AsyncDevice, Box<dyn std::error::Error>> {
|
||||
let mut config = tun::Configuration::default();
|
||||
config
|
||||
.name(name)
|
||||
.layer(tun::Layer::L3)
|
||||
.mtu(LINK_MTU)
|
||||
.queues(1)
|
||||
.up();
|
||||
let tun = tun::create_as_async(&config)?;
|
||||
|
||||
Ok(tun)
|
||||
}
|
||||
|
||||
impl IfaceName {
|
||||
fn as_ptr(&self) -> *const libc::c_char {
|
||||
self.0.as_ptr()
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for IfaceName {
|
||||
type Err = &'static str;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
// Equal len is not allowed because we need to add the 0 byte terminator.
|
||||
if s.len() >= libc::IFNAMSIZ {
|
||||
return Err("Interface name too long");
|
||||
}
|
||||
|
||||
// TODO: Is this err possible in a &str?
|
||||
let raw_name = CString::new(s).map_err(|_| "Interface name contains 0 byte")?;
|
||||
let mut backing = [0; libc::IFNAMSIZ];
|
||||
let name_bytes = raw_name.to_bytes_with_nul();
|
||||
backing[..name_bytes.len()].copy_from_slice(name_bytes);
|
||||
// SAFETY: This doesn't do any weird things with the bits when converting from u8 to i8
|
||||
let backing = unsafe { std::mem::transmute::<[u8; 16], [i8; 16]>(backing) };
|
||||
Ok(Self(backing))
|
||||
}
|
||||
}
|
||||
|
||||
impl Iface {
|
||||
/// Retrieve the link index of an interface with the given name
|
||||
fn by_name(name: &str) -> Result<Iface, Box<dyn std::error::Error>> {
|
||||
let iface_name: IfaceName = name.parse()?;
|
||||
match unsafe { libc::if_nametoindex(iface_name.as_ptr()) } {
|
||||
0 => Err(std::io::Error::new(
|
||||
std::io::ErrorKind::NotFound,
|
||||
"interface not found",
|
||||
))?,
|
||||
_ => Ok(Iface { iface_name }),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add an address to an interface.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Only IPv6 is supported, this function will panic when adding an IPv4 subnet.
|
||||
fn add_address(
|
||||
&self,
|
||||
subnet: Subnet,
|
||||
route_subnet: Subnet,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let addr = if let IpAddr::V6(addr) = subnet.address() {
|
||||
addr
|
||||
} else {
|
||||
panic!("IPv4 subnets are not supported");
|
||||
};
|
||||
let mask_addr = if let IpAddr::V6(mask) = route_subnet.mask() {
|
||||
mask
|
||||
} else {
|
||||
// We already know we are IPv6 here
|
||||
panic!("IPv4 routes are not supported");
|
||||
};
|
||||
|
||||
let sock_addr = SockaddrIn6::from(std::net::SocketAddrV6::new(addr, 0, 0, 0));
|
||||
let mask = SockaddrIn6::from(std::net::SocketAddrV6::new(mask_addr, 0, 0, 0));
|
||||
|
||||
let req = IfaliasReq {
|
||||
ifname: self.iface_name,
|
||||
addr: sock_addr,
|
||||
// SAFETY: kernel expects this to be fully zeroed
|
||||
dst_addr: unsafe { std::mem::zeroed() },
|
||||
mask,
|
||||
flags: IN6_IFF_NODAD | IN6_IFF_SECURED,
|
||||
lifetime: AddressLifetime {
|
||||
expire: 0,
|
||||
preferred: 0,
|
||||
vltime: ND6_INFINITE_LIFETIME,
|
||||
pltime: ND6_INFINITE_LIFETIME,
|
||||
},
|
||||
};
|
||||
|
||||
let sock = random_socket()?;
|
||||
|
||||
match unsafe { siocaifaddr_in6(sock.as_raw_fd(), &req) } {
|
||||
Err(e) => {
|
||||
error!("Failed to add ipv6 addresst to interface {e}");
|
||||
Err(std::io::Error::last_os_error())?
|
||||
}
|
||||
Ok(_) => {
|
||||
debug!("Added {subnet} to tun interfacel");
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Create a socket to talk to the kernel.
|
||||
fn random_socket() -> Result<std::net::UdpSocket, std::io::Error> {
|
||||
std::net::UdpSocket::bind("[::1]:0")
|
||||
}
|
||||
|
||||
nix::ioctl_write_ptr!(
|
||||
/// Add an IPv6 subnet to an interface.
|
||||
siocaifaddr_in6,
|
||||
b'i',
|
||||
26,
|
||||
IfaliasReq
|
||||
);
|
||||
100
mycelium/src/tun/ios.rs
Normal file
100
mycelium/src/tun/ios.rs
Normal file
@@ -0,0 +1,100 @@
|
||||
//! ios specific tun interface setup.
|
||||
|
||||
use std::io::{self, IoSlice};
|
||||
|
||||
use futures::{Sink, Stream};
|
||||
use tokio::{
|
||||
io::{AsyncReadExt, AsyncWriteExt},
|
||||
select,
|
||||
sync::mpsc,
|
||||
};
|
||||
use tracing::{error, info};
|
||||
|
||||
use crate::crypto::PacketBuffer;
|
||||
use crate::tun::TunConfig;
|
||||
|
||||
// TODO
|
||||
const LINK_MTU: i32 = 1400;
|
||||
|
||||
/// The 4 byte packet header written before a packet is sent on the TUN
|
||||
// TODO: figure out structure and values, but for now this seems to work.
|
||||
const HEADER: [u8; 4] = [0, 0, 0, 30];
|
||||
|
||||
/// Create a new tun interface and set required routes
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// This function will panic if called outside of the context of a tokio runtime.
|
||||
pub async fn new(
|
||||
tun_config: TunConfig,
|
||||
) -> Result<
|
||||
(
|
||||
impl Stream<Item = io::Result<PacketBuffer>>,
|
||||
impl Sink<PacketBuffer, Error = impl std::error::Error> + Clone,
|
||||
),
|
||||
Box<dyn std::error::Error>,
|
||||
> {
|
||||
let mut tun = create_tun_interface(tun_config.tun_fd)?;
|
||||
|
||||
let (tun_sink, mut sink_receiver) = mpsc::channel::<PacketBuffer>(1000);
|
||||
let (tun_stream, stream_receiver) = mpsc::unbounded_channel();
|
||||
|
||||
// Spawn a single task to manage the TUN interface
|
||||
tokio::spawn(async move {
|
||||
let mut buf_hold = None;
|
||||
loop {
|
||||
let mut buf: PacketBuffer = buf_hold.take().unwrap_or_default();
|
||||
|
||||
select! {
|
||||
data = sink_receiver.recv() => {
|
||||
match data {
|
||||
None => return,
|
||||
Some(data) => {
|
||||
// We need to append a 4 byte header here
|
||||
if let Err(e) = tun.write_vectored(&[IoSlice::new(&HEADER), IoSlice::new(&data)]).await {
|
||||
error!("Failed to send data to tun interface {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
// Save the buffer as we didn't use it
|
||||
buf_hold = Some(buf);
|
||||
}
|
||||
read_result = tun.read(buf.buffer_mut()) => {
|
||||
let rr = read_result.map(|n| {
|
||||
buf.set_size(n);
|
||||
// Trim header
|
||||
buf.buffer_mut().copy_within(4.., 0);
|
||||
buf.set_size(n-4);
|
||||
buf
|
||||
});
|
||||
|
||||
|
||||
if tun_stream.send(rr).is_err() {
|
||||
error!("Could not forward data to tun stream, receiver is gone");
|
||||
break;
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
info!("Stop reading from / writing to tun interface");
|
||||
});
|
||||
|
||||
Ok((
|
||||
tokio_stream::wrappers::UnboundedReceiverStream::new(stream_receiver),
|
||||
tokio_util::sync::PollSender::new(tun_sink),
|
||||
))
|
||||
}
|
||||
|
||||
/// Create a new TUN interface
|
||||
fn create_tun_interface(tun_fd: i32) -> Result<tun::AsyncDevice, Box<dyn std::error::Error>> {
|
||||
let mut config = tun::Configuration::default();
|
||||
config
|
||||
.layer(tun::Layer::L3)
|
||||
.mtu(LINK_MTU)
|
||||
.queues(1)
|
||||
.raw_fd(tun_fd)
|
||||
.up();
|
||||
let tun = tun::create_as_async(&config)?;
|
||||
|
||||
Ok(tun)
|
||||
}
|
||||
156
mycelium/src/tun/linux.rs
Normal file
156
mycelium/src/tun/linux.rs
Normal file
@@ -0,0 +1,156 @@
|
||||
//! Linux specific tun interface setup.
|
||||
|
||||
use std::io;
|
||||
|
||||
use futures::{Sink, Stream, TryStreamExt};
|
||||
use rtnetlink::Handle;
|
||||
use tokio::{select, sync::mpsc};
|
||||
use tokio_tun::{Tun, TunBuilder};
|
||||
use tracing::{error, info};
|
||||
|
||||
use crate::crypto::PacketBuffer;
|
||||
use crate::subnet::Subnet;
|
||||
use crate::tun::TunConfig;
|
||||
|
||||
// TODO
|
||||
const LINK_MTU: i32 = 1400;
|
||||
|
||||
/// Create a new tun interface and set required routes
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// This function will panic if called outside of the context of a tokio runtime.
|
||||
pub async fn new(
|
||||
tun_config: TunConfig,
|
||||
) -> Result<
|
||||
(
|
||||
impl Stream<Item = io::Result<PacketBuffer>>,
|
||||
impl Sink<PacketBuffer, Error = impl std::error::Error> + Clone,
|
||||
),
|
||||
Box<dyn std::error::Error>,
|
||||
> {
|
||||
let tun = match create_tun_interface(&tun_config.name) {
|
||||
Ok(tun) => tun,
|
||||
Err(e) => {
|
||||
error!(
|
||||
"Could not create tun device named \"{}\", make sure the name is not yet in use, and you have sufficient privileges to create a network device",
|
||||
tun_config.name,
|
||||
);
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
|
||||
let (conn, handle, _) = rtnetlink::new_connection()?;
|
||||
let netlink_task_handle = tokio::spawn(conn);
|
||||
|
||||
let tun_index = link_index_by_name(handle.clone(), tun_config.name).await?;
|
||||
|
||||
if let Err(e) = add_address(
|
||||
handle.clone(),
|
||||
tun_index,
|
||||
Subnet::new(
|
||||
tun_config.node_subnet.address(),
|
||||
tun_config.route_subnet.prefix_len(),
|
||||
)
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
error!(
|
||||
"Failed to add address {0} to TUN interface: {e}",
|
||||
tun_config.node_subnet
|
||||
);
|
||||
return Err(e);
|
||||
}
|
||||
|
||||
// We are done with our netlink connection, abort the task so we can properly clean up.
|
||||
netlink_task_handle.abort();
|
||||
|
||||
let (tun_sink, mut sink_receiver) = mpsc::channel::<PacketBuffer>(1000);
|
||||
let (tun_stream, stream_receiver) = mpsc::unbounded_channel();
|
||||
|
||||
// Spawn a single task to manage the TUN interface
|
||||
tokio::spawn(async move {
|
||||
let mut buf_hold = None;
|
||||
loop {
|
||||
let mut buf: PacketBuffer = buf_hold.take().unwrap_or_default();
|
||||
|
||||
select! {
|
||||
data = sink_receiver.recv() => {
|
||||
match data {
|
||||
None => return,
|
||||
Some(data) => {
|
||||
if let Err(e) = tun.send(&data).await {
|
||||
error!("Failed to send data to tun interface {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
// Save the buffer as we didn't use it
|
||||
buf_hold = Some(buf);
|
||||
}
|
||||
read_result = tun.recv(buf.buffer_mut()) => {
|
||||
let rr = read_result.map(|n| {
|
||||
buf.set_size(n);
|
||||
buf
|
||||
});
|
||||
|
||||
if tun_stream.send(rr).is_err() {
|
||||
error!("Could not forward data to tun stream, receiver is gone");
|
||||
break;
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
info!("Stop reading from / writing to tun interface");
|
||||
});
|
||||
|
||||
Ok((
|
||||
tokio_stream::wrappers::UnboundedReceiverStream::new(stream_receiver),
|
||||
tokio_util::sync::PollSender::new(tun_sink),
|
||||
))
|
||||
}
|
||||
|
||||
/// Create a new TUN interface
|
||||
fn create_tun_interface(name: &str) -> Result<Tun, Box<dyn std::error::Error>> {
|
||||
let tun = TunBuilder::new()
|
||||
.name(name)
|
||||
.mtu(LINK_MTU)
|
||||
.queues(1)
|
||||
.up()
|
||||
.build()?
|
||||
.pop()
|
||||
.expect("Succesfully build tun interface has 1 queue");
|
||||
|
||||
Ok(tun)
|
||||
}
|
||||
|
||||
/// Retrieve the link index of an interface with the given name
|
||||
async fn link_index_by_name(
|
||||
handle: Handle,
|
||||
name: String,
|
||||
) -> Result<u32, Box<dyn std::error::Error>> {
|
||||
handle
|
||||
.link()
|
||||
.get()
|
||||
.match_name(name)
|
||||
.execute()
|
||||
.try_next()
|
||||
.await?
|
||||
.map(|link_message| link_message.header.index)
|
||||
.ok_or(io::Error::new(io::ErrorKind::NotFound, "link not found").into())
|
||||
}
|
||||
|
||||
/// Add an address to an interface.
|
||||
///
|
||||
/// The kernel will automatically add a route entry for the subnet assigned to the interface.
|
||||
async fn add_address(
|
||||
handle: Handle,
|
||||
link_index: u32,
|
||||
subnet: Subnet,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
Ok(handle
|
||||
.address()
|
||||
.add(link_index, subnet.address(), subnet.prefix_len())
|
||||
.execute()
|
||||
.await?)
|
||||
}
|
||||
166
mycelium/src/tun/windows.rs
Normal file
166
mycelium/src/tun/windows.rs
Normal file
@@ -0,0 +1,166 @@
|
||||
use std::{io, ops::Deref, sync::Arc};
|
||||
|
||||
use futures::{Sink, Stream};
|
||||
use tokio::sync::mpsc;
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
use crate::tun::TunConfig;
|
||||
use crate::{crypto::PacketBuffer, subnet::Subnet};
|
||||
|
||||
// TODO
|
||||
const LINK_MTU: usize = 1400;
|
||||
|
||||
/// Type of the tunnel used, specified when creating the tunnel.
|
||||
const WINDOWS_TUNNEL_TYPE: &str = "Mycelium";
|
||||
|
||||
pub async fn new(
|
||||
tun_config: TunConfig,
|
||||
) -> Result<
|
||||
(
|
||||
impl Stream<Item = io::Result<PacketBuffer>>,
|
||||
impl Sink<PacketBuffer, Error = impl std::error::Error> + Clone,
|
||||
),
|
||||
Box<dyn std::error::Error>,
|
||||
> {
|
||||
// SAFETY: for now we assume a valid wintun.dll file exists in the root directory when we are
|
||||
// running this.
|
||||
let wintun = unsafe { wintun::load() }?;
|
||||
let wintun_version = match wintun::get_running_driver_version(&wintun) {
|
||||
Ok(v) => format!("{v}"),
|
||||
Err(e) => {
|
||||
warn!("Failed to read wintun.dll version: {e}");
|
||||
"Unknown".to_string()
|
||||
}
|
||||
};
|
||||
info!("Loaded wintun.dll - running version {wintun_version}");
|
||||
let tun = wintun::Adapter::create(&wintun, &tun_config.name, WINDOWS_TUNNEL_TYPE, None)?;
|
||||
info!("Created wintun tunnel interface");
|
||||
// Configure created network adapter.
|
||||
set_adapter_mtu(&tun_config.name, LINK_MTU)?;
|
||||
// Set address, this will use a `netsh` command under the hood unfortunately.
|
||||
// TODO: fix in library
|
||||
// tun.set_network_addresses_tuple(node_subnet.address(), route_subnet.mask(), None)?;
|
||||
add_address(
|
||||
&tun_config.name,
|
||||
tun_config.node_subnet,
|
||||
tun_config.route_subnet,
|
||||
)?;
|
||||
// Build 2 separate sessions - one for receiving, one for sending.
|
||||
let rx_session = Arc::new(tun.start_session(wintun::MAX_RING_CAPACITY)?);
|
||||
let tx_session = rx_session.clone();
|
||||
|
||||
let (tun_sink, mut sink_receiver) = mpsc::channel::<PacketBuffer>(1000);
|
||||
let (tun_stream, stream_receiver) = mpsc::unbounded_channel();
|
||||
|
||||
// Ingress path
|
||||
tokio::task::spawn_blocking(move || {
|
||||
loop {
|
||||
let packet = rx_session
|
||||
.receive_blocking()
|
||||
.map(|tun_packet| {
|
||||
let mut buffer = PacketBuffer::new();
|
||||
// SAFETY: The configured MTU is smaller than the static PacketBuffer size.
|
||||
let packet_len = tun_packet.bytes().len();
|
||||
buffer.buffer_mut()[..packet_len].copy_from_slice(tun_packet.bytes());
|
||||
buffer.set_size(packet_len);
|
||||
buffer
|
||||
})
|
||||
.map_err(wintun_to_io_error);
|
||||
|
||||
if tun_stream.send(packet).is_err() {
|
||||
error!("Could not forward data to tun stream, receiver is gone");
|
||||
break;
|
||||
};
|
||||
}
|
||||
|
||||
info!("Stop reading from tun interface");
|
||||
});
|
||||
|
||||
// Egress path
|
||||
tokio::task::spawn_blocking(move || {
|
||||
loop {
|
||||
match sink_receiver.blocking_recv() {
|
||||
None => break,
|
||||
Some(data) => {
|
||||
let mut tun_packet =
|
||||
match tx_session.allocate_send_packet(data.deref().len() as u16) {
|
||||
Ok(tun_packet) => tun_packet,
|
||||
Err(e) => {
|
||||
error!("Could not allocate packet on TUN: {e}");
|
||||
break;
|
||||
}
|
||||
};
|
||||
// SAFETY: packet allocation is done on the length of &data.
|
||||
tun_packet.bytes_mut().copy_from_slice(&data);
|
||||
tx_session.send_packet(tun_packet);
|
||||
}
|
||||
}
|
||||
}
|
||||
info!("Stop writing to tun interface");
|
||||
});
|
||||
|
||||
Ok((
|
||||
tokio_stream::wrappers::UnboundedReceiverStream::new(stream_receiver),
|
||||
tokio_util::sync::PollSender::new(tun_sink),
|
||||
))
|
||||
}
|
||||
|
||||
/// Helper method to convert a [`wintun::Error`] to a [`std::io::Error`].
|
||||
fn wintun_to_io_error(err: wintun::Error) -> io::Error {
|
||||
match err {
|
||||
wintun::Error::Io(e) => e,
|
||||
_ => io::Error::other("unknown wintun error"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set an address on an interface by shelling out to `netsh`
|
||||
///
|
||||
/// We assume this is an IPv6 address.
|
||||
fn add_address(adapter_name: &str, subnet: Subnet, route_subnet: Subnet) -> Result<(), io::Error> {
|
||||
let exit_code = std::process::Command::new("netsh")
|
||||
.args([
|
||||
"interface",
|
||||
"ipv6",
|
||||
"set",
|
||||
"address",
|
||||
adapter_name,
|
||||
&format!("{}/{}", subnet.address(), route_subnet.prefix_len()),
|
||||
])
|
||||
.spawn()?
|
||||
.wait()?;
|
||||
|
||||
match exit_code.code() {
|
||||
Some(0) => Ok(()),
|
||||
Some(x) => Err(io::Error::from_raw_os_error(x)),
|
||||
None => {
|
||||
warn!("Failed to determine `netsh` exit status");
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn set_adapter_mtu(name: &str, mtu: usize) -> Result<(), io::Error> {
|
||||
let args = &[
|
||||
"interface",
|
||||
"ipv6",
|
||||
"set",
|
||||
"subinterface",
|
||||
&format!("\"{name}\""),
|
||||
&format!("mtu={mtu}"),
|
||||
"store=persistent",
|
||||
];
|
||||
|
||||
let exit_code = std::process::Command::new("netsh")
|
||||
.args(args)
|
||||
.spawn()?
|
||||
.wait()?;
|
||||
|
||||
match exit_code.code() {
|
||||
Some(0) => Ok(()),
|
||||
Some(x) => Err(io::Error::from_raw_os_error(x)),
|
||||
None => {
|
||||
warn!("Failed to determine `netsh` exit status");
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user