995 lines
43 KiB
Rust
995 lines
43 KiB
Rust
use futures_channel::{mpsc, oneshot};
|
||
use futures_util::{FutureExt, SinkExt, StreamExt};
|
||
use log::{debug, error, info, warn};
|
||
use serde::{Deserialize, Serialize};
|
||
use serde_json::Value;
|
||
use std::collections::HashMap;
|
||
use std::sync::{Arc, Mutex};
|
||
use thiserror::Error;
|
||
use uuid::Uuid;
|
||
|
||
// Authentication module
|
||
pub mod auth;
|
||
|
||
pub use auth::{AuthCredentials, AuthError, AuthResult};
|
||
|
||
// Platform-specific WebSocket imports and spawn function
|
||
#[cfg(target_arch = "wasm32")]
|
||
use {
|
||
gloo_net::websocket::{futures::WebSocket, Message as GlooWsMessage},
|
||
wasm_bindgen_futures::spawn_local,
|
||
};
|
||
|
||
#[cfg(not(target_arch = "wasm32"))]
|
||
use {
|
||
tokio::spawn as spawn_local,
|
||
tokio_tungstenite::{
|
||
connect_async, connect_async_tls_with_config,
|
||
tungstenite::{
|
||
protocol::Message as TungsteniteWsMessage,
|
||
},
|
||
Connector,
|
||
},
|
||
};
|
||
|
||
// JSON-RPC Structures (client-side perspective)
|
||
#[derive(Serialize, Debug, Clone)]
|
||
pub struct JsonRpcRequestClient {
|
||
jsonrpc: String,
|
||
method: String,
|
||
params: Value,
|
||
id: String,
|
||
}
|
||
|
||
#[derive(Deserialize, Debug, Clone)]
|
||
pub struct JsonRpcResponseClient {
|
||
#[allow(dead_code)]
|
||
// Field is part of JSON-RPC spec, even if not directly used by client logic
|
||
jsonrpc: String,
|
||
pub result: Option<Value>,
|
||
pub error: Option<JsonRpcErrorClient>,
|
||
pub id: String,
|
||
}
|
||
|
||
#[derive(Deserialize, Debug, Clone)]
|
||
pub struct JsonRpcErrorClient {
|
||
pub code: i32,
|
||
pub message: String,
|
||
pub data: Option<Value>,
|
||
}
|
||
|
||
#[derive(Serialize, Debug, Clone)]
|
||
pub struct PlayParamsClient {
|
||
pub script: String,
|
||
}
|
||
|
||
#[derive(Deserialize, Debug, Clone)]
|
||
pub struct PlayResultClient {
|
||
pub output: String,
|
||
}
|
||
|
||
#[derive(Serialize, Debug, Clone)]
|
||
pub struct AuthCredentialsParams {
|
||
pub pubkey: String,
|
||
pub signature: String,
|
||
}
|
||
|
||
#[derive(Serialize, Debug, Clone)]
|
||
pub struct FetchNonceParams {
|
||
pub pubkey: String,
|
||
}
|
||
|
||
#[derive(Deserialize, Debug, Clone)]
|
||
pub struct FetchNonceResponse {
|
||
pub nonce: String,
|
||
}
|
||
|
||
#[derive(Error, Debug)]
|
||
pub enum CircleWsClientError {
|
||
#[error("WebSocket connection error: {0}")]
|
||
ConnectionError(String),
|
||
#[error("WebSocket send error: {0}")]
|
||
SendError(String),
|
||
#[error("WebSocket receive error: {0}")]
|
||
ReceiveError(String),
|
||
#[error("JSON serialization/deserialization error: {0}")]
|
||
JsonError(#[from] serde_json::Error),
|
||
#[error("Request timed out for request ID: {0}")]
|
||
Timeout(String),
|
||
#[error("JSON-RPC error response: {code} - {message}")]
|
||
JsonRpcError {
|
||
code: i32,
|
||
message: String,
|
||
data: Option<Value>,
|
||
},
|
||
#[error("No response received for request ID: {0}")]
|
||
NoResponse(String),
|
||
#[error("Client is not connected")]
|
||
NotConnected,
|
||
#[error("Internal channel error: {0}")]
|
||
ChannelError(String),
|
||
#[error("Authentication error: {0}")]
|
||
Auth(#[from] auth::AuthError),
|
||
#[error("Authentication requires a keypair, but none was provided.")]
|
||
AuthNoKeyPair,
|
||
}
|
||
|
||
// Wrapper for messages sent to the WebSocket task
|
||
enum InternalWsMessage {
|
||
SendJsonRpc(
|
||
JsonRpcRequestClient,
|
||
oneshot::Sender<Result<JsonRpcResponseClient, CircleWsClientError>>,
|
||
),
|
||
SendPlaintext(
|
||
String,
|
||
oneshot::Sender<Result<String, CircleWsClientError>>,
|
||
),
|
||
Close,
|
||
}
|
||
|
||
/// Builder that collects the settings needed to create a `CircleWsClient`.
pub struct CircleWsClientBuilder {
    /// WebSocket URL of the server.
    ws_url: String,
    /// Optional private key; `authenticate` fails with `AuthNoKeyPair` without it.
    private_key: Option<String>,
}
|
||
|
||
impl CircleWsClientBuilder {
|
||
pub fn new(ws_url: String) -> Self {
|
||
Self {
|
||
ws_url,
|
||
private_key: None,
|
||
}
|
||
}
|
||
|
||
pub fn with_keypair(mut self, private_key: String) -> Self {
|
||
self.private_key = Some(private_key);
|
||
self
|
||
}
|
||
|
||
pub fn build(self) -> CircleWsClient {
|
||
CircleWsClient {
|
||
ws_url: self.ws_url,
|
||
internal_tx: None,
|
||
#[cfg(not(target_arch = "wasm32"))]
|
||
task_handle: None,
|
||
private_key: self.private_key,
|
||
is_connected: Arc::new(Mutex::new(false)),
|
||
}
|
||
}
|
||
}
|
||
|
||
pub struct CircleWsClient {
|
||
ws_url: String,
|
||
internal_tx: Option<mpsc::Sender<InternalWsMessage>>,
|
||
#[cfg(not(target_arch = "wasm32"))]
|
||
task_handle: Option<tokio::task::JoinHandle<()>>,
|
||
private_key: Option<String>,
|
||
is_connected: Arc<Mutex<bool>>,
|
||
}
|
||
|
||
impl CircleWsClient {
|
||
/// Get the connection status
|
||
pub fn get_connection_status(&self) -> String {
|
||
if *self.is_connected.lock().unwrap() {
|
||
"Connected".to_string()
|
||
} else {
|
||
"Disconnected".to_string()
|
||
}
|
||
}
|
||
|
||
/// Check if the client is connected
|
||
pub fn is_connected(&self) -> bool {
|
||
*self.is_connected.lock().unwrap()
|
||
}
|
||
}
|
||
|
||
impl CircleWsClient {
|
||
pub async fn authenticate(&mut self) -> Result<bool, CircleWsClientError> {
|
||
info!("🔐 [{}] Starting authentication process...", self.ws_url);
|
||
|
||
let private_key = self
|
||
.private_key
|
||
.as_ref()
|
||
.ok_or(CircleWsClientError::AuthNoKeyPair)?;
|
||
|
||
info!("🔑 [{}] Deriving public key from private key...", self.ws_url);
|
||
let public_key = auth::derive_public_key(private_key)?;
|
||
info!("✅ [{}] Public key derived: {}...", self.ws_url, &public_key[..8]);
|
||
|
||
info!("🎫 [{}] Fetching authentication nonce...", self.ws_url);
|
||
let nonce = self.fetch_nonce(&public_key).await?;
|
||
info!("✅ [{}] Nonce received: {}...", self.ws_url, &nonce[..8]);
|
||
|
||
info!("✍️ [{}] Signing nonce with private key...", self.ws_url);
|
||
let signature = auth::sign_message(private_key, &nonce)?;
|
||
info!("✅ [{}] Signature created: {}...", self.ws_url, &signature[..8]);
|
||
|
||
info!("🔒 [{}] Submitting authentication credentials...", self.ws_url);
|
||
let result = self.authenticate_with_signature(&public_key, &signature).await?;
|
||
|
||
if result {
|
||
info!("🎉 [{}] Authentication successful!", self.ws_url);
|
||
} else {
|
||
error!("❌ [{}] Authentication failed - server rejected credentials", self.ws_url);
|
||
}
|
||
|
||
Ok(result)
|
||
}
|
||
|
||
async fn fetch_nonce(&self, pubkey: &str) -> Result<String, CircleWsClientError> {
|
||
info!("📡 [{}] Sending fetch_nonce request for pubkey: {}...", self.ws_url, &pubkey[..8]);
|
||
|
||
let params = FetchNonceParams {
|
||
pubkey: pubkey.to_string(),
|
||
};
|
||
let req = self.create_request("fetch_nonce", params)?;
|
||
let res = self.send_request(req).await?;
|
||
|
||
if let Some(err) = res.error {
|
||
error!("❌ [{}] fetch_nonce failed: {} (code: {})", self.ws_url, err.message, err.code);
|
||
return Err(CircleWsClientError::JsonRpcError {
|
||
code: err.code,
|
||
message: err.message,
|
||
data: err.data,
|
||
});
|
||
}
|
||
|
||
let nonce_res: FetchNonceResponse = serde_json::from_value(res.result.unwrap_or_default())?;
|
||
info!("✅ [{}] fetch_nonce successful, nonce length: {}", self.ws_url, nonce_res.nonce.len());
|
||
Ok(nonce_res.nonce)
|
||
}
|
||
|
||
async fn authenticate_with_signature(
|
||
&self,
|
||
pubkey: &str,
|
||
signature: &str,
|
||
) -> Result<bool, CircleWsClientError> {
|
||
info!("📡 [{}] Sending authenticate request with signature...", self.ws_url);
|
||
|
||
let params = AuthCredentialsParams {
|
||
pubkey: pubkey.to_string(),
|
||
signature: signature.to_string(),
|
||
};
|
||
let req = self.create_request("authenticate", params)?;
|
||
let res = self.send_request(req).await?;
|
||
|
||
if let Some(err) = res.error {
|
||
error!("❌ [{}] authenticate failed: {} (code: {})", self.ws_url, err.message, err.code);
|
||
return Err(CircleWsClientError::JsonRpcError {
|
||
code: err.code,
|
||
message: err.message,
|
||
data: err.data,
|
||
});
|
||
}
|
||
|
||
let authenticated = res
|
||
.result
|
||
.and_then(|v| v.get("authenticated").and_then(|v| v.as_bool()))
|
||
.unwrap_or(false);
|
||
|
||
if authenticated {
|
||
info!("✅ [{}] authenticate request successful - server confirmed authentication", self.ws_url);
|
||
} else {
|
||
error!("❌ [{}] authenticate request failed - server returned false", self.ws_url);
|
||
}
|
||
|
||
Ok(authenticated)
|
||
}
|
||
|
||
/// Call the whoami method to get authentication status and user information
|
||
pub async fn whoami(&self) -> Result<Value, CircleWsClientError> {
|
||
let req = self.create_request("whoami", serde_json::json!({}))?;
|
||
let response = self.send_request(req).await?;
|
||
|
||
if let Some(result) = response.result {
|
||
Ok(result)
|
||
} else if let Some(error) = response.error {
|
||
Err(CircleWsClientError::JsonRpcError {
|
||
code: error.code,
|
||
message: error.message,
|
||
data: error.data,
|
||
})
|
||
} else {
|
||
Err(CircleWsClientError::NoResponse("whoami".to_string()))
|
||
}
|
||
}
|
||
|
||
fn create_request<T: Serialize>(
|
||
&self,
|
||
method: &str,
|
||
params: T,
|
||
) -> Result<JsonRpcRequestClient, CircleWsClientError> {
|
||
Ok(JsonRpcRequestClient {
|
||
jsonrpc: "2.0".to_string(),
|
||
method: method.to_string(),
|
||
params: serde_json::to_value(params)?,
|
||
id: Uuid::new_v4().to_string(),
|
||
})
|
||
}
|
||
|
||
async fn send_request(
|
||
&self,
|
||
req: JsonRpcRequestClient,
|
||
) -> Result<JsonRpcResponseClient, CircleWsClientError> {
|
||
let (response_tx, response_rx) = oneshot::channel();
|
||
if let Some(mut tx) = self.internal_tx.clone() {
|
||
tx.send(InternalWsMessage::SendJsonRpc(req.clone(), response_tx))
|
||
.await
|
||
.map_err(|e| {
|
||
CircleWsClientError::ChannelError(format!(
|
||
"Failed to send request to internal task: {}",
|
||
e
|
||
))
|
||
})?;
|
||
} else {
|
||
return Err(CircleWsClientError::NotConnected);
|
||
}
|
||
|
||
#[cfg(target_arch = "wasm32")]
|
||
{
|
||
match response_rx.await {
|
||
Ok(Ok(rpc_response)) => Ok(rpc_response),
|
||
Ok(Err(e)) => Err(e),
|
||
Err(_) => Err(CircleWsClientError::Timeout(req.id)),
|
||
}
|
||
}
|
||
#[cfg(not(target_arch = "wasm32"))]
|
||
{
|
||
use tokio::time::timeout as tokio_timeout;
|
||
match tokio_timeout(std::time::Duration::from_secs(30), response_rx).await {
|
||
Ok(Ok(Ok(rpc_response))) => Ok(rpc_response),
|
||
Ok(Ok(Err(e))) => Err(e),
|
||
Ok(Err(_)) => Err(CircleWsClientError::ChannelError(
|
||
"Response channel cancelled".to_string(),
|
||
)),
|
||
Err(_) => Err(CircleWsClientError::Timeout(req.id)),
|
||
}
|
||
}
|
||
}
|
||
|
||
pub async fn connect(&mut self) -> Result<(), CircleWsClientError> {
|
||
if self.internal_tx.is_some() {
|
||
info!("🔄 [{}] Client already connected or connecting", self.ws_url);
|
||
return Ok(());
|
||
}
|
||
|
||
info!("🚀 [{}] Starting self-managed WebSocket connection with keep-alive and reconnection...", self.ws_url);
|
||
let (internal_tx, internal_rx) = mpsc::channel::<InternalWsMessage>(32);
|
||
self.internal_tx = Some(internal_tx);
|
||
|
||
// Clone necessary data for the task
|
||
let connection_url = self.ws_url.clone();
|
||
let private_key = self.private_key.clone();
|
||
let is_connected = self.is_connected.clone();
|
||
info!("🔗 [{}] Will handle connection, authentication, keep-alive, and reconnection internally", connection_url);
|
||
|
||
// Pending requests: map request_id to a oneshot sender for the response
|
||
let pending_requests: Arc<
|
||
Mutex<
|
||
HashMap<
|
||
String,
|
||
oneshot::Sender<Result<JsonRpcResponseClient, CircleWsClientError>>,
|
||
>,
|
||
>,
|
||
> = Arc::new(Mutex::new(HashMap::new()));
|
||
|
||
let task_pending_requests = pending_requests.clone();
|
||
let log_url = connection_url.clone();
|
||
|
||
let task = async move {
|
||
// Main connection loop with reconnection logic
|
||
loop {
|
||
info!("🔄 [{}] Starting connection attempt...", log_url);
|
||
|
||
// Reset connection status
|
||
*is_connected.lock().unwrap() = false;
|
||
|
||
// Clone connection_url for this iteration to avoid move issues
|
||
let connection_url_clone = connection_url.clone();
|
||
|
||
// Establish WebSocket connection
|
||
#[cfg(target_arch = "wasm32")]
|
||
let ws_result = WebSocket::open(&connection_url_clone);
|
||
|
||
#[cfg(not(target_arch = "wasm32"))]
|
||
let connect_attempt = async {
|
||
// Check if this is a secure WebSocket connection
|
||
if connection_url_clone.starts_with("wss://") {
|
||
// For WSS connections, use a custom TLS connector that accepts self-signed certificates
|
||
// This is for development/demo purposes only
|
||
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
|
||
|
||
let request = connection_url_clone.as_str().into_client_request()
|
||
.map_err(|e| CircleWsClientError::ConnectionError(format!("Invalid URL: {}", e)))?;
|
||
|
||
// Create a native-tls connector that accepts invalid certificates (for development)
|
||
let tls_connector = native_tls::TlsConnector::builder()
|
||
.danger_accept_invalid_certs(true)
|
||
.danger_accept_invalid_hostnames(true)
|
||
.build()
|
||
.map_err(|e| CircleWsClientError::ConnectionError(format!("TLS connector creation failed: {}", e)))?;
|
||
|
||
let connector = Connector::NativeTls(tls_connector);
|
||
|
||
warn!("⚠️ DEVELOPMENT MODE: Accepting self-signed certificates (NOT for production!)");
|
||
connect_async_tls_with_config(request, None, false, Some(connector))
|
||
.await
|
||
.map_err(|e| CircleWsClientError::ConnectionError(format!("WSS connection failed: {}", e)))
|
||
} else {
|
||
// For regular WS connections, use the standard method
|
||
connect_async(&connection_url_clone)
|
||
.await
|
||
.map_err(|e| CircleWsClientError::ConnectionError(format!("WS connection failed: {}", e)))
|
||
}
|
||
};
|
||
#[cfg(not(target_arch = "wasm32"))]
|
||
let ws_result = connect_attempt.await;
|
||
|
||
match ws_result {
|
||
Ok(ws_conn_maybe_response) => {
|
||
#[cfg(target_arch = "wasm32")]
|
||
let ws_conn = ws_conn_maybe_response;
|
||
#[cfg(not(target_arch = "wasm32"))]
|
||
let (ws_conn, _) = ws_conn_maybe_response;
|
||
|
||
// For WASM, WebSocket::open() always succeeds even if server is down
|
||
// We'll start as "connecting" and detect failures through timeouts
|
||
#[cfg(target_arch = "wasm32")]
|
||
info!("🔄 [{}] WebSocket object created, testing actual connectivity...", log_url);
|
||
#[cfg(not(target_arch = "wasm32"))]
|
||
{
|
||
info!("✅ [{}] WebSocket connection established successfully", log_url);
|
||
*is_connected.lock().unwrap() = true;
|
||
}
|
||
|
||
// Handle authentication if private key is provided
|
||
let auth_success = if let Some(ref _pk) = private_key {
|
||
info!("🔐 [{}] Authentication will be handled by separate authenticate() call", log_url);
|
||
true // For now, assume auth will be handled separately
|
||
} else {
|
||
info!("ℹ️ [{}] No private key provided, skipping authentication", log_url);
|
||
true
|
||
};
|
||
|
||
if auth_success {
|
||
// Start the main message handling loop with keep-alive
|
||
let disconnect_reason = Self::handle_connection_with_keepalive(
|
||
ws_conn,
|
||
internal_rx,
|
||
&task_pending_requests,
|
||
&log_url,
|
||
&is_connected
|
||
).await;
|
||
|
||
info!("🔌 [{}] Connection ended: {}", log_url, disconnect_reason);
|
||
|
||
// Check if this was a manual disconnect
|
||
if disconnect_reason == "Manual close requested" {
|
||
break; // Don't reconnect on manual close
|
||
}
|
||
|
||
// If we reach here, we need to recreate internal_rx for the next iteration
|
||
// But since internal_rx was moved, we need to break out of the loop
|
||
break;
|
||
}
|
||
}
|
||
Err(e) => {
|
||
error!("❌ [{}] WebSocket connection failed: {:?}", log_url, e);
|
||
}
|
||
}
|
||
|
||
// Reset connection status
|
||
*is_connected.lock().unwrap() = false;
|
||
|
||
// Wait before reconnecting
|
||
info!("⏳ [{}] Waiting 5 seconds before reconnection attempt...", log_url);
|
||
#[cfg(target_arch = "wasm32")]
|
||
{
|
||
use gloo_timers::future::TimeoutFuture;
|
||
TimeoutFuture::new(5_000).await;
|
||
}
|
||
#[cfg(not(target_arch = "wasm32"))]
|
||
{
|
||
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
|
||
}
|
||
}
|
||
|
||
// Cleanup pending requests on exit
|
||
task_pending_requests
|
||
.lock()
|
||
.unwrap()
|
||
.drain()
|
||
.for_each(|(_, sender)| {
|
||
let _ = sender.send(Err(CircleWsClientError::ConnectionError(
|
||
"WebSocket task terminated".to_string(),
|
||
)));
|
||
});
|
||
|
||
info!("🏁 [{}] WebSocket task finished", log_url);
|
||
};
|
||
|
||
#[cfg(target_arch = "wasm32")]
|
||
spawn_local(task);
|
||
#[cfg(not(target_arch = "wasm32"))]
|
||
{
|
||
self.task_handle = Some(spawn_local(task));
|
||
}
|
||
|
||
Ok(())
|
||
}
|
||
|
||
// Enhanced connection loop handler with keep-alive
|
||
/// Drives one WebSocket connection on `wasm32` until it ends.
///
/// Sends an initial `ping` and requires any message from the server within
/// 2 seconds to consider the connection usable, forwards commands from
/// `internal_rx` to the socket, routes incoming text messages to waiting
/// callers, and sends a plaintext `ping` every 30 seconds once validated.
///
/// Returns a human-readable reason for the connection ending; the caller
/// compares it against `"Manual close requested"`. `is_connected` is set to
/// `true` on the first received message and is always `false` on return.
#[cfg(target_arch = "wasm32")]
async fn handle_connection_with_keepalive(
    ws_conn: WebSocket,
    internal_rx: mpsc::Receiver<InternalWsMessage>,
    pending_requests: &Arc<Mutex<HashMap<String, oneshot::Sender<Result<JsonRpcResponseClient, CircleWsClientError>>>>>,
    log_url: &str,
    is_connected: &Arc<Mutex<bool>>,
) -> String {
    use gloo_timers::future::TimeoutFuture;

    let (mut ws_tx, mut ws_rx) = ws_conn.split();
    let mut internal_rx_fused = internal_rx.fuse();

    // Track plaintext requests (like ping)
    let pending_plaintext: Arc<Mutex<HashMap<String, oneshot::Sender<Result<String, CircleWsClientError>>>>> = Arc::new(Mutex::new(HashMap::new()));

    // Connection validation: the server must send something within 2 seconds.
    let mut connection_test_timer = TimeoutFuture::new(2_000).fuse();
    let mut connection_validated = false;

    // Keep-alive timer - send ping every 30 seconds
    let mut keep_alive_timer = TimeoutFuture::new(30_000).fuse();

    // Send initial connection test ping
    debug!("Sending initial connection test ping to {}", log_url);
    if let Err(e) = ws_tx.send(GlooWsMessage::Text("ping".to_string())).await {
        error!("❌ [{}] Initial connection test failed: {:?}", log_url, e);
        *is_connected.lock().unwrap() = false;
        return format!("Initial connection test failed: {}", e);
    }

    let disconnect_reason = loop {
        futures_util::select! {
            // Connection test timeout - if no response in 2 seconds, connection failed
            _ = connection_test_timer => {
                if !connection_validated {
                    error!("❌ [{}] Connection test failed - no response within 2 seconds", log_url);
                    break "Connection test timeout - server not responding".to_string();
                }
            }

            // Handle messages from the client's public methods (e.g., play)
            internal_msg = internal_rx_fused.next().fuse() => {
                match internal_msg {
                    Some(InternalWsMessage::SendJsonRpc(req, response_sender)) => {
                        let req_id = req.id.clone();
                        match serde_json::to_string(&req) {
                            Ok(req_str) => {
                                debug!("Sending JSON-RPC request (ID: {}): {}", req_id, req_str);
                                let send_res = ws_tx.send(GlooWsMessage::Text(req_str)).await;
                                if let Err(e) = send_res {
                                    error!("WebSocket send error for request ID {}: {:?}", req_id, e);
                                    // Send failed: report as disconnected but keep serving.
                                    *is_connected.lock().unwrap() = false;
                                    let _ = response_sender.send(Err(CircleWsClientError::SendError(e.to_string())));
                                } else {
                                    // Store the sender to await the response
                                    pending_requests.lock().unwrap().insert(req_id, response_sender);
                                }
                            }
                            Err(e) => {
                                error!("Failed to serialize request ID {}: {}", req_id, e);
                                let _ = response_sender.send(Err(CircleWsClientError::JsonError(e)));
                            }
                        }
                    }
                    Some(InternalWsMessage::SendPlaintext(text, response_sender)) => {
                        debug!("Sending plaintext message: {}", text);
                        let send_res = ws_tx.send(GlooWsMessage::Text(text.clone())).await;
                        if let Err(e) = send_res {
                            error!("WebSocket send error for plaintext message: {:?}", e);
                            *is_connected.lock().unwrap() = false;
                            let _ = response_sender.send(Err(CircleWsClientError::SendError(e.to_string())));
                        } else {
                            // Plaintext replies carry no id, so the waiter is
                            // stored under a locally generated key.
                            let request_id = format!("plaintext_{}", uuid::Uuid::new_v4());
                            pending_plaintext.lock().unwrap().insert(request_id, response_sender);
                        }
                    }
                    Some(InternalWsMessage::Close) => {
                        info!("Close message received internally, closing WebSocket.");
                        let _ = ws_tx.close().await;
                        break "Manual close requested".to_string();
                    }
                    None => {
                        info!("Internal MPSC channel closed, WebSocket task shutting down.");
                        let _ = ws_tx.close().await;
                        break "Internal channel closed".to_string();
                    }
                }
            },

            // Handle messages received from the WebSocket server
            ws_msg_res = ws_rx.next().fuse() => {
                match ws_msg_res {
                    Some(Ok(msg)) => {
                        // Any successful message confirms the connection is working
                        if !connection_validated {
                            info!("✅ [{}] WebSocket connection validated - received message from server", log_url);
                            *is_connected.lock().unwrap() = true;
                            connection_validated = true;
                        }

                        match msg {
                            GlooWsMessage::Text(text) => {
                                debug!("Received WebSocket message: {}", text);
                                Self::handle_received_message(&text, pending_requests, &pending_plaintext);
                            }
                            GlooWsMessage::Bytes(_) => {
                                debug!("Received binary WebSocket message (WASM).");
                            }
                        }
                    }
                    Some(Err(e)) => {
                        error!("WebSocket receive error: {:?}", e);
                        break format!("Receive error: {}", e);
                    }
                    None => {
                        info!("WebSocket connection closed by server (stream ended).");
                        break "Server closed connection (stream ended)".to_string();
                    }
                }
            }

            // Keep-alive timer - send ping every 30 seconds
            _ = keep_alive_timer => {
                // Only send ping if connection is validated
                if connection_validated {
                    debug!("Sending keep-alive ping to {}", log_url);
                    let ping_str = "ping"; // Send simple plaintext ping

                    let send_res = ws_tx.send(GlooWsMessage::Text(ping_str.to_string())).await;
                    if let Err(e) = send_res {
                        warn!("Keep-alive ping failed for {}: {:?}", log_url, e);
                        break format!("Keep-alive failed: {}", e);
                    }
                } else {
                    debug!("Skipping keep-alive ping - connection not yet validated for {}", log_url);
                }

                // Reset timer
                keep_alive_timer = TimeoutFuture::new(30_000).fuse();
            }
        }
    };

    // Single exit point: whatever ended the loop (including a manual close or
    // a dropped command channel), the connection is no longer usable.
    *is_connected.lock().unwrap() = false;
    disconnect_reason
}
|
||
|
||
// Enhanced connection loop handler with keep-alive for native targets
|
||
#[cfg(not(target_arch = "wasm32"))]
|
||
async fn handle_connection_with_keepalive(
|
||
ws_conn: tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
|
||
mut internal_rx: mpsc::Receiver<InternalWsMessage>,
|
||
pending_requests: &Arc<Mutex<HashMap<String, oneshot::Sender<Result<JsonRpcResponseClient, CircleWsClientError>>>>>,
|
||
log_url: &str,
|
||
_is_connected: &Arc<Mutex<bool>>,
|
||
) -> String {
|
||
let (mut ws_tx, mut ws_rx) = ws_conn.split();
|
||
let mut internal_rx_fused = internal_rx.fuse();
|
||
|
||
// Track plaintext requests (like ping)
|
||
let pending_plaintext: Arc<Mutex<HashMap<String, oneshot::Sender<Result<String, CircleWsClientError>>>>> = Arc::new(Mutex::new(HashMap::new()));
|
||
|
||
loop {
|
||
futures_util::select! {
|
||
// Handle messages from the client's public methods (e.g., play)
|
||
internal_msg = internal_rx_fused.next().fuse() => {
|
||
match internal_msg {
|
||
Some(InternalWsMessage::SendJsonRpc(req, response_sender)) => {
|
||
let req_id = req.id.clone();
|
||
match serde_json::to_string(&req) {
|
||
Ok(req_str) => {
|
||
debug!("Sending JSON-RPC request (ID: {}): {}", req_id, req_str);
|
||
let send_res = ws_tx.send(TungsteniteWsMessage::Text(req_str)).await;
|
||
if let Err(e) = send_res {
|
||
error!("WebSocket send error for request ID {}: {:?}", req_id, e);
|
||
let _ = response_sender.send(Err(CircleWsClientError::SendError(e.to_string())));
|
||
} else {
|
||
// Store the sender to await the response
|
||
pending_requests.lock().unwrap().insert(req_id, response_sender);
|
||
}
|
||
}
|
||
Err(e) => {
|
||
error!("Failed to serialize request ID {}: {}", req_id, e);
|
||
let _ = response_sender.send(Err(CircleWsClientError::JsonError(e)));
|
||
}
|
||
}
|
||
}
|
||
Some(InternalWsMessage::SendPlaintext(text, response_sender)) => {
|
||
debug!("Sending plaintext message: {}", text);
|
||
let send_res = ws_tx.send(TungsteniteWsMessage::Text(text.clone())).await;
|
||
if let Err(e) = send_res {
|
||
error!("WebSocket send error for plaintext message: {:?}", e);
|
||
let _ = response_sender.send(Err(CircleWsClientError::SendError(e.to_string())));
|
||
} else {
|
||
// For plaintext messages like ping, we expect an immediate response
|
||
// Store the response sender to await the response (e.g., pong)
|
||
let request_id = format!("plaintext_{}", uuid::Uuid::new_v4());
|
||
pending_plaintext.lock().unwrap().insert(request_id, response_sender);
|
||
}
|
||
}
|
||
Some(InternalWsMessage::Close) => {
|
||
info!("Close message received internally, closing WebSocket.");
|
||
let _ = ws_tx.close().await;
|
||
return "Manual close requested".to_string();
|
||
}
|
||
None => {
|
||
info!("Internal MPSC channel closed, WebSocket task shutting down.");
|
||
let _ = ws_tx.close().await;
|
||
return "Internal channel closed".to_string();
|
||
}
|
||
}
|
||
},
|
||
|
||
// Handle messages received from the WebSocket server
|
||
ws_msg_res = ws_rx.next().fuse() => {
|
||
match ws_msg_res {
|
||
Some(Ok(msg)) => {
|
||
match msg {
|
||
TungsteniteWsMessage::Text(text) => {
|
||
debug!("Received WebSocket message: {}", text);
|
||
Self::handle_received_message(&text, pending_requests, &pending_plaintext);
|
||
}
|
||
TungsteniteWsMessage::Binary(_) => {
|
||
debug!("Received binary WebSocket message (Native).");
|
||
}
|
||
TungsteniteWsMessage::Ping(_) | TungsteniteWsMessage::Pong(_) => {
|
||
debug!("Received Ping/Pong (Native).");
|
||
}
|
||
TungsteniteWsMessage::Close(_) => {
|
||
info!("WebSocket connection closed by server (Native).");
|
||
return "Server closed connection".to_string();
|
||
}
|
||
TungsteniteWsMessage::Frame(_) => {
|
||
debug!("Received Frame (Native) - not typically handled directly.");
|
||
}
|
||
}
|
||
}
|
||
Some(Err(e)) => {
|
||
error!("WebSocket receive error: {:?}", e);
|
||
return format!("Receive error: {}", e);
|
||
}
|
||
None => {
|
||
info!("WebSocket connection closed by server (stream ended).");
|
||
return "Server closed connection (stream ended)".to_string();
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// Helper method to handle received messages
|
||
fn handle_received_message(
|
||
text: &str,
|
||
pending_requests: &Arc<Mutex<HashMap<String, oneshot::Sender<Result<JsonRpcResponseClient, CircleWsClientError>>>>>,
|
||
pending_plaintext: &Arc<Mutex<HashMap<String, oneshot::Sender<Result<String, CircleWsClientError>>>>>,
|
||
) {
|
||
// Handle ping/pong messages - these are not JSON-RPC
|
||
if text.trim() == "pong" {
|
||
debug!("Received pong response");
|
||
// Find and respond to any pending plaintext ping requests
|
||
let mut plaintext_map = pending_plaintext.lock().unwrap();
|
||
if let Some((_, sender)) = plaintext_map.drain().next() {
|
||
let _ = sender.send(Ok("pong".to_string()));
|
||
}
|
||
return;
|
||
}
|
||
|
||
match serde_json::from_str::<JsonRpcResponseClient>(text) {
|
||
Ok(response) => {
|
||
if let Some(sender) = pending_requests.lock().unwrap().remove(&response.id) {
|
||
if let Err(failed_send_val) = sender.send(Ok(response)) {
|
||
if let Ok(resp_for_log) = failed_send_val {
|
||
warn!("Failed to send response to waiting task for ID: {}", resp_for_log.id);
|
||
} else {
|
||
warn!("Failed to send response to waiting task, and also failed to get original response for logging.");
|
||
}
|
||
}
|
||
} else {
|
||
warn!("Received response for unknown request ID or unsolicited message: {:?}", response);
|
||
}
|
||
}
|
||
Err(e) => {
|
||
error!("Failed to parse JSON-RPC response: {}. Raw: {}", e, text);
|
||
}
|
||
}
|
||
}
|
||
|
||
pub fn play(
|
||
&self,
|
||
script: String,
|
||
) -> impl std::future::Future<Output = Result<PlayResultClient, CircleWsClientError>> + Send + 'static
|
||
{
|
||
let req_id_outer = Uuid::new_v4().to_string();
|
||
|
||
// Clone the sender option. The sender itself (mpsc::Sender) is also Clone.
|
||
let internal_tx_clone_opt = self.internal_tx.clone();
|
||
|
||
async move {
|
||
let req_id = req_id_outer; // Move req_id into the async block
|
||
let params = PlayParamsClient { script }; // script is moved in
|
||
|
||
let request = match serde_json::to_value(params) {
|
||
Ok(p_val) => JsonRpcRequestClient {
|
||
jsonrpc: "2.0".to_string(),
|
||
method: "play".to_string(),
|
||
params: p_val,
|
||
id: req_id.clone(),
|
||
},
|
||
Err(e) => return Err(CircleWsClientError::JsonError(e)),
|
||
};
|
||
|
||
let (response_tx, response_rx) = oneshot::channel();
|
||
|
||
if let Some(mut internal_tx) = internal_tx_clone_opt {
|
||
internal_tx
|
||
.send(InternalWsMessage::SendJsonRpc(request, response_tx))
|
||
.await
|
||
.map_err(|e| {
|
||
CircleWsClientError::ChannelError(format!(
|
||
"Failed to send request to internal task: {}",
|
||
e
|
||
))
|
||
})?;
|
||
} else {
|
||
return Err(CircleWsClientError::NotConnected);
|
||
}
|
||
|
||
// Add a timeout for waiting for the response
|
||
// For simplicity, using a fixed timeout here. Could be configurable.
|
||
#[cfg(target_arch = "wasm32")]
|
||
{
|
||
match response_rx.await {
|
||
Ok(Ok(rpc_response)) => {
|
||
if let Some(json_rpc_error) = rpc_response.error {
|
||
Err(CircleWsClientError::JsonRpcError {
|
||
code: json_rpc_error.code,
|
||
message: json_rpc_error.message,
|
||
data: json_rpc_error.data,
|
||
})
|
||
} else if let Some(result_value) = rpc_response.result {
|
||
serde_json::from_value(result_value)
|
||
.map_err(CircleWsClientError::JsonError)
|
||
} else {
|
||
Err(CircleWsClientError::NoResponse(req_id.clone()))
|
||
}
|
||
}
|
||
Ok(Err(e)) => Err(e), // Error propagated from the ws task
|
||
Err(_) => Err(CircleWsClientError::Timeout(req_id.clone())), // oneshot channel cancelled
|
||
}
|
||
}
|
||
#[cfg(not(target_arch = "wasm32"))]
|
||
{
|
||
use tokio::time::timeout as tokio_timeout;
|
||
match tokio_timeout(std::time::Duration::from_secs(10), response_rx).await {
|
||
Ok(Ok(Ok(rpc_response))) => {
|
||
// Timeout -> Result<ChannelRecvResult, Error>
|
||
if let Some(json_rpc_error) = rpc_response.error {
|
||
Err(CircleWsClientError::JsonRpcError {
|
||
code: json_rpc_error.code,
|
||
message: json_rpc_error.message,
|
||
data: json_rpc_error.data,
|
||
})
|
||
} else if let Some(result_value) = rpc_response.result {
|
||
serde_json::from_value(result_value)
|
||
.map_err(CircleWsClientError::JsonError)
|
||
} else {
|
||
Err(CircleWsClientError::NoResponse(req_id.clone()))
|
||
}
|
||
}
|
||
Ok(Ok(Err(e))) => Err(e), // Error propagated from the ws task
|
||
Ok(Err(_)) => Err(CircleWsClientError::ChannelError(
|
||
"Response channel cancelled".to_string(),
|
||
)), // oneshot cancelled
|
||
Err(_) => Err(CircleWsClientError::Timeout(req_id.clone())), // tokio_timeout expired
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
/// Send a plaintext ping message and wait for pong response
|
||
pub async fn ping(&mut self) -> Result<String, CircleWsClientError> {
|
||
if let Some(mut tx) = self.internal_tx.clone() {
|
||
let (response_tx, response_rx) = oneshot::channel();
|
||
|
||
// Send plaintext ping message
|
||
tx.send(InternalWsMessage::SendPlaintext("ping".to_string(), response_tx))
|
||
.await
|
||
.map_err(|e| {
|
||
CircleWsClientError::ChannelError(format!(
|
||
"Failed to send ping request to internal task: {}",
|
||
e
|
||
))
|
||
})?;
|
||
|
||
// Wait for pong response with timeout
|
||
#[cfg(target_arch = "wasm32")]
|
||
{
|
||
match response_rx.await {
|
||
Ok(Ok(response)) => Ok(response),
|
||
Ok(Err(e)) => Err(e),
|
||
Err(_) => Err(CircleWsClientError::ChannelError(
|
||
"Ping response channel cancelled".to_string(),
|
||
)),
|
||
}
|
||
}
|
||
#[cfg(not(target_arch = "wasm32"))]
|
||
{
|
||
use tokio::time::timeout as tokio_timeout;
|
||
match tokio_timeout(std::time::Duration::from_secs(10), response_rx).await {
|
||
Ok(Ok(Ok(response))) => Ok(response),
|
||
Ok(Ok(Err(e))) => Err(e),
|
||
Ok(Err(_)) => Err(CircleWsClientError::ChannelError(
|
||
"Ping response channel cancelled".to_string(),
|
||
)),
|
||
Err(_) => Err(CircleWsClientError::Timeout("ping".to_string())),
|
||
}
|
||
}
|
||
} else {
|
||
Err(CircleWsClientError::NotConnected)
|
||
}
|
||
}
|
||
|
||
pub async fn disconnect(&mut self) {
|
||
if let Some(mut tx) = self.internal_tx.take() {
|
||
info!("Sending close signal to internal WebSocket task.");
|
||
let _ = tx.send(InternalWsMessage::Close).await;
|
||
}
|
||
#[cfg(not(target_arch = "wasm32"))]
|
||
if let Some(handle) = self.task_handle.take() {
|
||
let _ = handle.await; // Wait for the task to finish
|
||
}
|
||
info!("Client disconnected.");
|
||
}
|
||
}
|
||
|
||
// Ensure client cleans up on drop for native targets
|
||
#[cfg(not(target_arch = "wasm32"))]
|
||
impl Drop for CircleWsClient {
|
||
fn drop(&mut self) {
|
||
if self.internal_tx.is_some() || self.task_handle.is_some() {
|
||
warn!("CircleWsClient dropped without explicit disconnect. Spawning task to send close signal.");
|
||
// We can't call async disconnect directly in drop.
|
||
// Spawn a new task to send the close message if on native.
|
||
if let Some(mut tx) = self.internal_tx.take() {
|
||
spawn_local(async move {
|
||
info!("Drop: Sending close signal to internal WebSocket task.");
|
||
let _ = tx.send(InternalWsMessage::Close).await;
|
||
});
|
||
}
|
||
if let Some(handle) = self.task_handle.take() {
|
||
spawn_local(async move {
|
||
info!("Drop: Waiting for WebSocket task to finish.");
|
||
let _ = handle.await;
|
||
info!("Drop: WebSocket task finished.");
|
||
});
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
#[cfg(test)]
mod tests {
    /// Placeholder test that only proves the crate builds and the harness runs.
    #[test]
    fn it_compiles() {
        let sum = 2 + 2;
        assert_eq!(sum, 4);
    }
}
|