Use single cached supervisorclient

Signed-off-by: Lee Smet <lee.smet@hotmail.com>
This commit is contained in:
Lee Smet
2025-09-05 17:09:57 +02:00
parent 2c88114d45
commit fb34b4e2f3

View File

@@ -1,9 +1,11 @@
use std::{collections::HashSet, sync::Arc};
use std::{collections::{HashSet, HashMap}, sync::Arc};
use base64::Engine;
use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
use serde_json::{Value, json};
use tokio::sync::Semaphore;
use tokio::sync::{Semaphore, Mutex};
use std::hash::{Hash, Hasher};
use std::collections::hash_map::DefaultHasher;
use crate::{
clients::{Destination, MyceliumClient, SupervisorClient},
@@ -23,6 +25,88 @@ pub struct RouterConfig {
pub transport_poll_timeout_secs: u64, // e.g. 300 (5 minutes)
}
/*
SupervisorClient reuse cache (Router-local):
Rationale:
- SupervisorClient maintains an internal JSON-RPC id_counter per instance.
- Rebuilding a client for each message resets this counter, causing inner JSON-RPC ids to restart at 1.
- We reuse one SupervisorClient per (destination, topic, secret) to preserve monotonically increasing ids.
Scope:
- Cache is per Router loop (and a separate one for the inbound listener).
- If cross-loop/process reuse becomes necessary later, promote to a process-global cache.
Keying:
- Key: destination + topic + secret-presence (secret content hashed; not stored in plaintext).
Concurrency:
- tokio::Mutex protects a HashMap<String, Arc<SupervisorClient>>.
- Values are Arc so call sites clone cheaply and share the same id_counter.
*/
#[derive(Clone)]
struct SupervisorClientCache {
map: Arc<Mutex<HashMap<String, Arc<SupervisorClient>>>>,
}
impl SupervisorClientCache {
fn new() -> Self {
Self {
map: Arc::new(Mutex::new(HashMap::new())),
}
}
fn make_key(dest: &Destination, topic: &str, secret: &Option<String>) -> String {
let dst = match dest {
Destination::Ip(ip) => format!("ip:{ip}"),
Destination::Pk(pk) => format!("pk:{pk}"),
};
// Hash the secret to avoid storing plaintext in keys while still differentiating values
let sec_hash = match secret {
Some(s) if !s.is_empty() => {
let mut hasher = DefaultHasher::new();
s.hash(&mut hasher);
format!("s:{}", hasher.finish())
}
_ => "s:none".to_string(),
};
format!("{dst}|t:{topic}|{sec_hash}")
}
async fn get_or_create(
&self,
mycelium: Arc<MyceliumClient>,
dest: Destination,
topic: String,
secret: Option<String>,
) -> Arc<SupervisorClient> {
let key = Self::make_key(&dest, &topic, &secret);
{
let guard = self.map.lock().await;
if let Some(existing) = guard.get(&key) {
tracing::debug!(target: "router", cache="supervisor", hit=true, %topic, secret = %if secret.is_some() { "set" } else { "none" }, "SupervisorClient cache lookup");
return existing.clone();
}
}
let mut guard = self.map.lock().await;
if let Some(existing) = guard.get(&key) {
tracing::debug!(target: "router", cache="supervisor", hit=true, %topic, secret = %if secret.is_some() { "set" } else { "none" }, "SupervisorClient cache lookup (double-checked)");
return existing.clone();
}
let client = Arc::new(SupervisorClient::new_with_client(
mycelium,
dest,
topic.clone(),
secret.clone(),
));
guard.insert(key, client.clone());
tracing::debug!(target: "router", cache="supervisor", hit=false, %topic, secret = %if secret.is_some() { "set" } else { "none" }, "SupervisorClient cache insert");
client
}
}
/// Start background router loops, one per context.
/// Each loop:
/// - BRPOP msg_out with 1s timeout
@@ -49,6 +133,8 @@ pub fn start_router(service: AppService, cfg: RouterConfig) -> Vec<tokio::task::
}
};
let cache = Arc::new(SupervisorClientCache::new());
loop {
// Pop next message key (blocking with timeout)
match service_cloned.brpop_msg_out(ctx_id, 1).await {
@@ -69,11 +155,12 @@ pub fn start_router(service: AppService, cfg: RouterConfig) -> Vec<tokio::task::
let cfg_task = cfg_cloned.clone();
tokio::spawn({
let mycelium = mycelium.clone();
let cache = cache.clone();
async move {
// Ensure permit is dropped at end of task
let _permit = permit;
if let Err(e) =
deliver_one(&service_task, &cfg_task, ctx_id, &key, mycelium)
deliver_one(&service_task, &cfg_task, ctx_id, &key, mycelium, cache.clone())
.await
{
error!(context_id=ctx_id, key=%key, error=%e, "Delivery error");
@@ -104,6 +191,7 @@ async fn deliver_one(
context_id: u32,
msg_key: &str,
mycelium: Arc<MyceliumClient>,
cache: Arc<SupervisorClientCache>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
// Parse "message:{caller_id}:{id}"
let (caller_id, id) = parse_message_key(msg_key)
@@ -143,12 +231,14 @@ async fn deliver_one(
let dest_for_poller = dest.clone();
let topic_for_poller = cfg.topic.clone();
let secret_for_poller = runner.secret.clone();
let client = SupervisorClient::new_with_client(
mycelium.clone(),
dest.clone(),
cfg.topic.clone(),
runner.secret.clone(),
);
let client = cache
.get_or_create(
mycelium.clone(),
dest.clone(),
cfg.topic.clone(),
runner.secret.clone(),
)
.await;
// Build supervisor method and params from Message
let method = msg.message.clone();
@@ -251,12 +341,14 @@ async fn deliver_one(
// if matches!(s, TransportStatus::Read)
// && let Some(job_id) = job_id_opt
if let Some(job_id) = job_id_opt {
let sup = SupervisorClient::new_with_client(
client.clone(),
sup_dest.clone(),
sup_topic.clone(),
secret_for_poller.clone(),
);
let sup = cache
.get_or_create(
client.clone(),
sup_dest.clone(),
sup_topic.clone(),
secret_for_poller.clone(),
)
.await;
match sup.job_status_with_ids(job_id.to_string()).await {
Ok((_out_id, inner_id)) => {
// Correlate this status request to the message/job
@@ -425,6 +517,8 @@ pub fn start_inbound_listener(
}
};
let cache = Arc::new(SupervisorClientCache::new());
loop {
// Poll for inbound supervisor messages on the configured topic
match mycelium.pop_message(Some(false), Some(20), None).await {
@@ -599,12 +693,14 @@ pub fn start_inbound_listener(
} else {
Destination::Ip(runner.address)
};
let sup = SupervisorClient::new_with_client(
mycelium.clone(),
dest,
cfg.topic.clone(),
runner.secret.clone(),
);
let sup = cache
.get_or_create(
mycelium.clone(),
dest,
cfg.topic.clone(),
runner.secret.clone(),
)
.await;
match sup
.job_result_with_ids(
job_id.to_string(),