use std::{net::IpAddr, ops::Deref, time::Duration}; use axum::{ extract::{Path, Query, State}, http::StatusCode, routing::{delete, get, post}, Json, Router, }; use serde::{Deserialize, Serialize}; use tracing::debug; use mycelium::{ crypto::PublicKey, message::{MessageId, MessageInfo}, metrics::Metrics, subnet::Subnet, }; use std::path::PathBuf; use super::ServerState; /// Default amount of time to try and send a message if it is not explicitly specified. const DEFAULT_MESSAGE_TRY_DURATION: Duration = Duration::from_secs(60 * 5); /// Return a router which has message endpoints and their handlers mounted. pub fn message_router_v1(server_state: ServerState) -> Router where M: Metrics + Clone + Send + Sync + 'static, { Router::new() .route("/messages", get(get_message).post(push_message)) .route("/messages/status/{id}", get(message_status)) .route("/messages/reply/{id}", post(reply_message)) // Topic configuration endpoints .route( "/messages/topics/default", get(get_default_topic_action).put(set_default_topic_action), ) .route("/messages/topics", get(get_topics).post(add_topic)) .route("/messages/topics/{topic}", delete(remove_topic)) .route( "/messages/topics/{topic}/sources", get(get_topic_sources).post(add_topic_source), ) .route( "/messages/topics/{topic}/sources/{subnet}", delete(remove_topic_source), ) .route( "/messages/topics/{topic}/forward", get(get_topic_forward_socket) .put(set_topic_forward_socket) .delete(remove_topic_forward_socket), ) .with_state(server_state) } #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct MessageSendInfo { pub dst: MessageDestination, #[serde(default)] #[serde(skip_serializing_if = "Option::is_none")] #[serde(with = "base64::optional_binary")] pub topic: Option>, #[serde(with = "base64::binary")] pub payload: Vec, } #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub enum MessageDestination { Ip(IpAddr), Pk(PublicKey), } #[derive(Clone, Deserialize, Serialize)] #[serde(rename_all = "camelCase")] pub struct MessageReceiveInfo { pub id: MessageId, pub src_ip: IpAddr, pub src_pk: PublicKey, pub dst_ip: IpAddr, pub dst_pk: PublicKey, #[serde(default)] #[serde(skip_serializing_if = "Option::is_none")] #[serde(with = "base64::optional_binary")] pub topic: Option>, #[serde(with = "base64::binary")] pub payload: Vec, } impl MessageDestination { /// Get the IP address of the destination. fn ip(self) -> IpAddr { match self { MessageDestination::Ip(ip) => ip, MessageDestination::Pk(pk) => IpAddr::V6(pk.address()), } } } #[derive(Deserialize)] struct GetMessageQuery { peek: Option, timeout: Option, /// Optional filter for start of the message, base64 encoded. #[serde(default)] #[serde(skip_serializing_if = "Option::is_none")] #[serde(with = "base64::optional_binary")] topic: Option>, } impl GetMessageQuery { /// Did the query indicate we should peek the message instead of pop? fn peek(&self) -> bool { matches!(self.peek, Some(true)) } /// Amount of seconds to hold and try and get values. fn timeout_secs(&self) -> u64 { self.timeout.unwrap_or(0) } } async fn get_message( State(state): State>, Query(query): Query, ) -> Result, StatusCode> where M: Metrics + Clone + Send + Sync + 'static, { debug!( "Attempt to get message, peek {}, timeout {} seconds", query.peek(), query.timeout_secs() ); // A timeout of 0 seconds essentially means get a message if there is one, and return // immediatly if there isn't. This is the result of the implementation of Timeout, which does a // poll of the internal future first, before polling the delay. tokio::time::timeout( Duration::from_secs(query.timeout_secs()), state .node .lock() .await .get_message(!query.peek(), query.topic), ) .await .or(Err(StatusCode::NO_CONTENT)) .map(|m| { Json(MessageReceiveInfo { id: m.id, src_ip: m.src_ip, src_pk: m.src_pk, dst_ip: m.dst_ip, dst_pk: m.dst_pk, topic: if m.topic.is_empty() { None } else { Some(m.topic) }, payload: m.data, }) }) } #[derive(Clone, Deserialize, Serialize)] #[serde(rename_all = "camelCase")] pub struct MessageIdReply { pub id: MessageId, } #[derive(Clone, Deserialize, Serialize)] #[serde(rename_all = "camelCase")] #[serde(untagged)] pub enum PushMessageResponse { Reply(MessageReceiveInfo), Id(MessageIdReply), } #[derive(Clone, Deserialize)] struct PushMessageQuery { reply_timeout: Option, } impl PushMessageQuery { /// The user requested to wait for the reply or not. fn await_reply(&self) -> bool { self.reply_timeout.is_some() } /// Amount of seconds to wait for the reply. fn timeout(&self) -> u64 { self.reply_timeout.unwrap_or(0) } } async fn push_message( State(state): State>, Query(query): Query, Json(message_info): Json, ) -> Result<(StatusCode, Json), StatusCode> where M: Metrics + Clone + Send + Sync + 'static, { let dst = message_info.dst.ip(); debug!( message.dst=%dst, message.len=message_info.payload.len(), "Pushing new message to stack", ); let (id, sub) = match state.node.lock().await.push_message( dst, message_info.payload, message_info.topic, DEFAULT_MESSAGE_TRY_DURATION, query.await_reply(), ) { Ok((id, sub)) => (id, sub), Err(_) => { return Err(StatusCode::BAD_REQUEST); } }; if !query.await_reply() { // If we don't wait for the reply just return here. return Ok(( StatusCode::CREATED, Json(PushMessageResponse::Id(MessageIdReply { id })), )); } let mut sub = sub.unwrap(); tokio::select! { sub_res = sub.changed() => { match sub_res { Ok(_) => { if let Some(m) = sub.borrow().deref() { Ok((StatusCode::OK, Json(PushMessageResponse::Reply(MessageReceiveInfo { id: m.id, src_ip: m.src_ip, src_pk: m.src_pk, dst_ip: m.dst_ip, dst_pk: m.dst_pk, topic: if m.topic.is_empty() { None } else { Some(m.topic.clone()) }, payload: m.data.clone(), })))) } else { // This happens if a none value is send, which should not happen. Err(StatusCode::INTERNAL_SERVER_ERROR) } } Err(_) => { // This happens if the sender drops, which should not happen. Err(StatusCode::INTERNAL_SERVER_ERROR) } } }, _ = tokio::time::sleep(Duration::from_secs(query.timeout())) => { // Timeout expired while waiting for reply Ok((StatusCode::REQUEST_TIMEOUT, Json(PushMessageResponse::Id(MessageIdReply { id })))) } } } async fn reply_message( State(state): State>, Path(id): Path, Json(message_info): Json, ) -> StatusCode where M: Metrics + Clone + Send + Sync + 'static, { let dst = message_info.dst.ip(); debug!( message.id=id.as_hex(), message.dst=%dst, message.len=message_info.payload.len(), "Pushing new reply to message stack", ); state.node.lock().await.reply_message( id, dst, message_info.payload, DEFAULT_MESSAGE_TRY_DURATION, ); StatusCode::NO_CONTENT } async fn message_status( State(state): State>, Path(id): Path, ) -> Result, StatusCode> where M: Metrics + Clone + Send + Sync + 'static, { debug!(message.id=%id.as_hex(), "Fetching message status"); state .node .lock() .await .message_status(id) .ok_or(StatusCode::NOT_FOUND) .map(Json) } /// Module to implement base64 decoding and encoding /// Sourced from https://users.rust-lang.org/t/serialize-a-vec-u8-to-json-as-base64/57781, with some /// addaptions to work with the new version of the base64 crate mod base64 { use base64::engine::{GeneralPurpose, GeneralPurposeConfig}; use base64::{alphabet, Engine}; const B64ENGINE: GeneralPurpose = base64::engine::general_purpose::GeneralPurpose::new( &alphabet::STANDARD, GeneralPurposeConfig::new(), ); pub fn encode(input: &[u8]) -> String { B64ENGINE.encode(input) } pub fn decode(input: &[u8]) -> Result, base64::DecodeError> { B64ENGINE.decode(input) } pub mod binary { use super::B64ENGINE; use base64::Engine; use serde::{Deserialize, Serialize}; use serde::{Deserializer, Serializer}; pub fn serialize(v: &Vec, s: S) -> Result { let base64 = B64ENGINE.encode(v); String::serialize(&base64, s) } pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result, D::Error> { let base64 = String::deserialize(d)?; B64ENGINE .decode(base64.as_bytes()) .map_err(serde::de::Error::custom) } } pub mod optional_binary { use super::B64ENGINE; use base64::Engine; use serde::{Deserialize, Serialize}; use serde::{Deserializer, Serializer}; pub fn serialize(v: &Option>, s: S) -> Result { if let Some(v) = v { let base64 = B64ENGINE.encode(v); String::serialize(&base64, s) } else { >::serialize(&None, s) } } pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result>, D::Error> { if let Some(base64) = >::deserialize(d)? { B64ENGINE .decode(base64.as_bytes()) .map_err(serde::de::Error::custom) .map(Option::Some) } else { Ok(None) } } } } // Topic configuration API /// Response for the default topic action #[derive(Serialize, Deserialize)] #[serde(rename_all = "camelCase")] struct DefaultTopicActionResponse { accept: bool, } /// Request to set the default topic action #[derive(Serialize, Deserialize)] #[serde(rename_all = "camelCase")] struct DefaultTopicActionRequest { accept: bool, } /// Request to add a source to a topic whitelist #[derive(Serialize, Deserialize)] #[serde(rename_all = "camelCase")] struct TopicSourceRequest { subnet: String, } /// Request to set a forward socket for a topic #[derive(Serialize, Deserialize)] #[serde(rename_all = "camelCase")] struct TopicForwardSocketRequest { socket_path: String, } /// Get the default topic action (accept or reject) async fn get_default_topic_action( State(state): State>, ) -> Json where M: Metrics + Clone + Send + Sync + 'static, { debug!("Getting default topic action"); let accept = state.node.lock().await.unconfigure_topic_action(); Json(DefaultTopicActionResponse { accept }) } /// Set the default topic action (accept or reject) async fn set_default_topic_action( State(state): State>, Json(request): Json, ) -> StatusCode where M: Metrics + Clone + Send + Sync + 'static, { debug!(accept=%request.accept, "Setting default topic action"); state .node .lock() .await .accept_unconfigured_topic(request.accept); StatusCode::NO_CONTENT } /// Get all whitelisted topics async fn get_topics(State(state): State>) -> Json> where M: Metrics + Clone + Send + Sync + 'static, { debug!("Getting all whitelisted topics"); let node = state.node.lock().await; // Get the whitelist from the node let topics = node.topics(); // Convert to TopicInfo structs let topics: Vec = topics.iter().map(|topic| base64::encode(topic)).collect(); Json(topics) } /// Add a topic to the whitelist async fn add_topic( State(state): State>, Json(topic_info): Json>, ) -> StatusCode where M: Metrics + Clone + Send + Sync + 'static, { debug!("Adding topic to whitelist"); state.node.lock().await.add_topic_whitelist(topic_info); StatusCode::CREATED } /// Remove a topic from the whitelist async fn remove_topic( State(state): State>, Path(topic): Path, ) -> Result where M: Metrics + Clone + Send + Sync + 'static, { debug!("Removing topic from whitelist"); // Decode the base64 topic let topic_bytes = match base64::decode(topic.as_bytes()) { Ok(bytes) => bytes, Err(_) => return Err(StatusCode::BAD_REQUEST), }; state.node.lock().await.remove_topic_whitelist(topic_bytes); Ok(StatusCode::NO_CONTENT) } /// Get all sources for a topic async fn get_topic_sources( State(state): State>, Path(topic): Path, ) -> Result>, StatusCode> where M: Metrics + Clone + Send + Sync + 'static, { debug!("Getting sources for topic"); // Decode the base64 topic let topic_bytes = match base64::decode(topic.as_bytes()) { Ok(bytes) => bytes, Err(_) => return Err(StatusCode::BAD_REQUEST), }; let node = state.node.lock().await; // Get the whitelist from the node let sources = node.topic_allowed_sources(&topic_bytes); // Find the topic in the whitelist if let Some(sources) = sources { let sources = sources.into_iter().map(|s| s.to_string()).collect(); Ok(Json(sources)) } else { Err(StatusCode::NOT_FOUND) } } /// Add a source to a topic whitelist async fn add_topic_source( State(state): State>, Path(topic): Path, Json(request): Json, ) -> Result where M: Metrics + Clone + Send + Sync + 'static, { debug!("Adding source to topic whitelist"); // Decode the base64 topic let topic_bytes = match base64::decode(topic.as_bytes()) { Ok(bytes) => bytes, Err(_) => return Err(StatusCode::BAD_REQUEST), }; // Parse the subnet let subnet = match request.subnet.parse::() { Ok(subnet) => subnet, Err(_) => return Err(StatusCode::BAD_REQUEST), }; state .node .lock() .await .add_topic_whitelist_src(topic_bytes, subnet); Ok(StatusCode::CREATED) } /// Remove a source from a topic whitelist async fn remove_topic_source( State(state): State>, Path((topic, subnet_str)): Path<(String, String)>, ) -> Result where M: Metrics + Clone + Send + Sync + 'static, { debug!("Removing source from topic whitelist"); // Decode the base64 topic let topic_bytes = match base64::decode(topic.as_bytes()) { Ok(bytes) => bytes, Err(_) => return Err(StatusCode::BAD_REQUEST), }; // Parse the subnet let subnet = match subnet_str.parse::() { Ok(subnet) => subnet, Err(_) => return Err(StatusCode::BAD_REQUEST), }; state .node .lock() .await .remove_topic_whitelist_src(topic_bytes, subnet); Ok(StatusCode::NO_CONTENT) } /// Get the forward socket for a topic async fn get_topic_forward_socket( State(state): State>, Path(topic): Path, ) -> Result>, StatusCode> where M: Metrics + Clone + Send + Sync + 'static, { debug!("Getting forward socket for topic"); // Decode the base64 topic let topic_bytes = match base64::decode(topic.as_bytes()) { Ok(bytes) => bytes, Err(_) => return Err(StatusCode::BAD_REQUEST), }; let node = state.node.lock().await; let socket_path = node .get_topic_forward_socket(&topic_bytes) .map(|p| p.to_string_lossy().to_string()); Ok(Json(socket_path)) } /// Set the forward socket for a topic async fn set_topic_forward_socket( State(state): State>, Path(topic): Path, Json(request): Json, ) -> Result where M: Metrics + Clone + Send + Sync + 'static, { debug!("Setting forward socket for topic"); // Decode the base64 topic let topic_bytes = match base64::decode(topic.as_bytes()) { Ok(bytes) => bytes, Err(_) => return Err(StatusCode::BAD_REQUEST), }; let socket_path = PathBuf::from(request.socket_path); state .node .lock() .await .set_topic_forward_socket(topic_bytes, socket_path); Ok(StatusCode::NO_CONTENT) } /// Remove the forward socket for a topic async fn remove_topic_forward_socket( State(state): State>, Path(topic): Path, ) -> Result where M: Metrics + Clone + Send + Sync + 'static, { debug!("Removing forward socket for topic"); // Decode the base64 topic let topic_bytes = match base64::decode(topic.as_bytes()) { Ok(bytes) => bytes, Err(_) => return Err(StatusCode::BAD_REQUEST), }; state .node .lock() .await .delete_topic_forward_socket(topic_bytes); Ok(StatusCode::NO_CONTENT) }