315 lines
15 KiB
Rust
315 lines
15 KiB
Rust
use std::time::{Duration, Instant};
|
|
use std::sync::{Arc, RwLock};
|
|
use std::collections::HashMap;
|
|
use actix::prelude::*;
|
|
use actix_web_actors::ws;
|
|
use crate::protocol::SignRequest;
|
|
use crate::registry::ConnectionRegistry;
|
|
use crate::crypto::SignatureVerifier;
|
|
use uuid::Uuid;
|
|
use log::{info, warn, error};
|
|
use sha2::{Sha256, Digest};
|
|
|
|
// Heartbeat functionality has been removed
|
|
|
|
/// WebSocket connection manager for handling signing operations
|
|
pub struct SigSocketManager {
|
|
/// Registry of connections
|
|
pub registry: Arc<RwLock<ConnectionRegistry>>,
|
|
/// Public key of the connection
|
|
pub public_key: Option<String>,
|
|
/// Pending requests with their response channels
|
|
pub pending_requests: HashMap<String, tokio::sync::oneshot::Sender<String>>,
|
|
}
|
|
|
|
impl SigSocketManager {
|
|
pub fn new(registry: Arc<RwLock<ConnectionRegistry>>) -> Self {
|
|
Self {
|
|
registry,
|
|
public_key: None,
|
|
pending_requests: HashMap::new(),
|
|
}
|
|
}
|
|
|
|
// Heartbeat functionality has been removed
|
|
|
|
/// Helper method to extract request ID from a message
|
|
fn extract_request_id(&self, message: &str) -> Option<String> {
|
|
// The client sends the original base64 message, which is the request ID directly
|
|
// But try to be robust in case the format changes
|
|
|
|
// First try to handle the case where the message is exactly the request ID
|
|
if message.len() >= 8 && message.contains('-') {
|
|
// This looks like it might be a UUID directly
|
|
return Some(message.to_string());
|
|
}
|
|
|
|
// Next try to parse as JSON (in case we get a JSON structure)
|
|
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(message) {
|
|
if let Some(id) = parsed.get("id").and_then(|v| v.as_str()) {
|
|
return Some(id.to_string());
|
|
}
|
|
}
|
|
|
|
// Finally, just treat the entire message as the key
|
|
// This is a fallback and may not find a match
|
|
info!("Using full message as request ID fallback: {}", message);
|
|
Some(message.to_string())
|
|
}
|
|
|
|
/// Process messages received over the websocket
|
|
fn handle_text_message(&mut self, text: String, ctx: &mut ws::WebsocketContext<Self>) {
|
|
// If this is the first message and we don't have a public key yet, treat it as an introduction
|
|
if self.public_key.is_none() {
|
|
// Validate the public key format
|
|
match hex::decode(&text) {
|
|
Ok(pk_bytes) => {
|
|
// Further validate with secp256k1
|
|
match secp256k1::PublicKey::from_slice(&pk_bytes) {
|
|
Ok(_) => {
|
|
// This is a valid public key, register it
|
|
info!("Registered connection for public key: {}", text);
|
|
self.public_key = Some(text.clone());
|
|
|
|
// Register in the connection registry
|
|
if let Ok(mut registry) = self.registry.write() {
|
|
registry.register(text.clone(), ctx.address());
|
|
}
|
|
|
|
// Acknowledge
|
|
ctx.text("Connected");
|
|
}
|
|
Err(_) => {
|
|
warn!("Invalid secp256k1 public key format: {}", text);
|
|
ctx.text("Invalid public key format - must be valid secp256k1");
|
|
ctx.close(Some(ws::CloseReason {
|
|
code: ws::CloseCode::Invalid,
|
|
description: Some("Invalid public key format".into()),
|
|
}));
|
|
}
|
|
}
|
|
}
|
|
Err(e) => {
|
|
error!("Invalid hex format for public key: {}", e);
|
|
ctx.text("Invalid public key format - must be hex encoded");
|
|
ctx.close(Some(ws::CloseReason {
|
|
code: ws::CloseCode::Invalid,
|
|
description: Some("Invalid public key format".into()),
|
|
}));
|
|
}
|
|
}
|
|
return;
|
|
}
|
|
|
|
// If we have a public key, this is either a response to a signing request
|
|
// New Format: JSON with id, message, signature fields
|
|
info!("Received message from client with public key: {}", self.public_key.as_ref().unwrap_or(&"<NONE>".to_string()));
|
|
info!("Raw message content: {}", text);
|
|
|
|
// Special case for confirmation message
|
|
if text == "CONFIRM_SIGNATURE_SENT" {
|
|
info!("Received confirmation message after signature");
|
|
return;
|
|
}
|
|
|
|
// Try to parse the message as JSON
|
|
match serde_json::from_str::<serde_json::Value>(&text) {
|
|
Ok(json) => {
|
|
info!("Successfully parsed message as JSON");
|
|
|
|
// Extract fields from the JSON response
|
|
let request_id = json.get("id").and_then(|v| v.as_str());
|
|
let message_b64 = json.get("message").and_then(|v| v.as_str());
|
|
let signature_b64 = json.get("signature").and_then(|v| v.as_str());
|
|
|
|
match (request_id, message_b64, signature_b64) {
|
|
(Some(id), Some(message), Some(signature)) => {
|
|
info!("Extracted request ID: {}", id);
|
|
info!("Parsed message part (base64): {}", message);
|
|
info!("Parsed signature part (base64): {}", signature);
|
|
|
|
// Try to decode both parts
|
|
info!("Attempting to decode base64 message and signature");
|
|
match (
|
|
base64::Engine::decode(&base64::engine::general_purpose::STANDARD, message),
|
|
base64::Engine::decode(&base64::engine::general_purpose::STANDARD, signature),
|
|
) {
|
|
(Ok(message), Ok(signature)) => {
|
|
info!("Successfully decoded message and signature");
|
|
info!("Message bytes (decoded): {:?}", message);
|
|
info!("Signature bytes (length): {} bytes", signature.len());
|
|
|
|
// Calculate the message hash (this is implementation specific)
|
|
let mut hasher = Sha256::new();
|
|
hasher.update(&message);
|
|
let message_hash = hasher.finalize();
|
|
info!("Calculated message hash: {:?}", message_hash);
|
|
|
|
// Verify the signature with the public key
|
|
if let Some(ref public_key) = self.public_key {
|
|
info!("Using public key for verification: {}", public_key);
|
|
let sig_hex = hex::encode(&signature);
|
|
info!("Signature (hex): {}", sig_hex);
|
|
|
|
info!("!!! ATTEMPTING SIGNATURE VERIFICATION !!!");
|
|
match SignatureVerifier::verify_signature(
|
|
public_key,
|
|
&message,
|
|
&sig_hex,
|
|
) {
|
|
Ok(true) => {
|
|
info!("!!! SIGNATURE VERIFICATION SUCCESSFUL !!!");
|
|
|
|
// We already have the request ID from the JSON!
|
|
info!("Using request ID directly from JSON: {}", id);
|
|
|
|
// Find and complete the pending request using the ID from the JSON
|
|
if let Some(sender) = self.pending_requests.remove(id) {
|
|
info!("Found pending request with ID: {}", id);
|
|
|
|
// Format the message and signature for the receiver
|
|
// Use base64 for BOTH message and signature as per the protocol requirements
|
|
let response = format!("{}.{}",
|
|
base64::Engine::encode(&base64::engine::general_purpose::STANDARD, &message),
|
|
base64::Engine::encode(&base64::engine::general_purpose::STANDARD, &signature));
|
|
|
|
info!("Formatted response: {} (truncated for log)",
|
|
if response.len() > 50 { &response[..50] } else { &response });
|
|
|
|
// Send the response directly using the stored channel
|
|
info!("Sending signature via direct response channel");
|
|
if sender.send(response).is_err() {
|
|
error!("Failed to send signature via response channel for request {}", id);
|
|
} else {
|
|
info!("!!! SUCCESSFULLY SENT SIGNATURE VIA RESPONSE CHANNEL FOR REQUEST {} !!!", id);
|
|
}
|
|
} else {
|
|
error!("No pending request found with ID: {}", id);
|
|
info!("Current pending requests: {:?}", self.pending_requests.keys().collect::<Vec<_>>());
|
|
}
|
|
},
|
|
Ok(false) => {
|
|
warn!("!!! SIGNATURE VERIFICATION FAILED - INVALID SIGNATURE !!!");
|
|
ctx.text("Invalid signature");
|
|
},
|
|
Err(e) => {
|
|
error!("!!! SIGNATURE VERIFICATION ERROR: {} !!!", e);
|
|
ctx.text("Error verifying signature");
|
|
}
|
|
}
|
|
} else {
|
|
error!("Missing public key for verification");
|
|
ctx.text("Missing public key for verification");
|
|
}
|
|
},
|
|
(Err(e1), _) => {
|
|
warn!("Failed to decode base64 message: {}", e1);
|
|
ctx.text("Invalid base64 encoding in message");
|
|
},
|
|
(_, Err(e2)) => {
|
|
warn!("Failed to decode base64 signature: {}", e2);
|
|
ctx.text("Invalid base64 encoding in signature");
|
|
}
|
|
}
|
|
},
|
|
_ => {
|
|
warn!("Missing required fields in JSON response");
|
|
ctx.text("Missing required fields in JSON response");
|
|
}
|
|
}
|
|
},
|
|
Err(e) => {
|
|
warn!("Received message in invalid JSON format: {} - {}", text, e);
|
|
ctx.text("Invalid JSON format");
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Handler for SignRequest message
|
|
impl Handler<SignRequest> for SigSocketManager {
|
|
type Result = ();
|
|
|
|
fn handle(&mut self, msg: SignRequest, ctx: &mut Self::Context) {
|
|
// We'll only process sign requests if we have a valid public key
|
|
if self.public_key.is_none() {
|
|
error!("Received sign request for connection without a public key");
|
|
return;
|
|
}
|
|
|
|
// Debug log the current pending requests in the manager
|
|
info!("*** MANAGER: Current pending requests before handling sign request: {:?} ***",
|
|
self.pending_requests.keys().collect::<Vec<_>>());
|
|
|
|
// If we received a response sender, store it for later
|
|
if let Some(sender) = msg.response_sender {
|
|
// Store the request ID and sender in our pending requests map
|
|
self.pending_requests.insert(msg.request_id.clone(), sender);
|
|
|
|
info!("*** MANAGER: Added pending request with response channel: {} ***", msg.request_id);
|
|
info!("*** MANAGER: Current pending requests after adding: {:?} ***",
|
|
self.pending_requests.keys().collect::<Vec<_>>());
|
|
} else {
|
|
warn!("Received SignRequest without response channel for ID: {}", msg.request_id);
|
|
}
|
|
|
|
// Create JSON message to send to the client
|
|
let message_b64 = base64::Engine::encode(&base64::engine::general_purpose::STANDARD, &msg.message);
|
|
let request_json = format!("{{\"id\": \"{}\", \"message\": \"{}\"}}",
|
|
msg.request_id, message_b64);
|
|
|
|
// Send the request to the client
|
|
ctx.text(request_json);
|
|
|
|
info!("Sent sign request {} to client {}", msg.request_id, self.public_key.as_ref().unwrap());
|
|
}
|
|
}
|
|
|
|
/// Handler for WebSocket messages
|
|
impl StreamHandler<Result<ws::Message, ws::ProtocolError>> for SigSocketManager {
|
|
fn handle(&mut self, msg: Result<ws::Message, ws::ProtocolError>, ctx: &mut Self::Context) {
|
|
match msg {
|
|
Ok(ws::Message::Ping(msg)) => {
|
|
// Simply respond to ping with pong - no heartbeat tracking
|
|
ctx.pong(&msg);
|
|
}
|
|
Ok(ws::Message::Pong(_)) => {
|
|
// No need to track heartbeat anymore
|
|
}
|
|
Ok(ws::Message::Text(text)) => {
|
|
self.handle_text_message(text.to_string(), ctx);
|
|
}
|
|
Ok(ws::Message::Binary(_)) => {
|
|
// We don't expect binary messages in this protocol
|
|
warn!("Unexpected binary message received");
|
|
}
|
|
Ok(ws::Message::Close(reason)) => {
|
|
info!("Client disconnected");
|
|
ctx.close(reason);
|
|
ctx.stop();
|
|
}
|
|
_ => ctx.stop(),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Actor for SigSocketManager {
|
|
type Context = ws::WebsocketContext<Self>;
|
|
|
|
fn started(&mut self, _ctx: &mut Self::Context) {
|
|
// Heartbeat functionality has been removed
|
|
info!("WebSocket connection established");
|
|
}
|
|
|
|
fn stopped(&mut self, _ctx: &mut Self::Context) {
|
|
// Unregister from the registry if we have a public key
|
|
if let Some(ref pk) = self.public_key {
|
|
info!("WebSocket connection closed for {}", pk);
|
|
|
|
if let Ok(mut registry) = self.registry.write() {
|
|
registry.unregister(pk);
|
|
}
|
|
}
|
|
}
|
|
}
|