feat: Enhance request management in SigSocket client with new methods and structures

This commit is contained in:
Sameh Abouel-saad
2025-06-05 20:02:34 +03:00
parent 580fd72dce
commit 6b037537bf
6 changed files with 806 additions and 8 deletions

View File

@@ -1,9 +1,18 @@
//! Main client interface for sigsocket communication
#[cfg(target_arch = "wasm32")]
use alloc::{string::String, vec::Vec, boxed::Box};
use alloc::{string::String, vec::Vec, boxed::Box, string::ToString};
#[cfg(not(target_arch = "wasm32"))]
use std::collections::HashMap;
#[cfg(target_arch = "wasm32")]
use alloc::collections::BTreeMap as HashMap;
use crate::{SignRequest, SignResponse, Result, SigSocketError};
use crate::protocol::ManagedSignRequest;
/// Connection state of the sigsocket client
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
@@ -67,6 +76,10 @@ pub struct SigSocketClient {
state: ConnectionState,
/// Sign request handler
sign_handler: Option<Box<dyn SignRequestHandler>>,
/// Pending sign requests managed by the client
pending_requests: HashMap<String, ManagedSignRequest>,
/// Connected public key (hex-encoded) - set when connection is established
connected_public_key: Option<String>,
/// Platform-specific implementation
#[cfg(not(target_arch = "wasm32"))]
inner: Option<crate::native::NativeClient>,
@@ -100,14 +113,16 @@ impl SigSocketClient {
public_key,
state: ConnectionState::Disconnected,
sign_handler: None,
pending_requests: HashMap::new(),
connected_public_key: None,
inner: None,
})
}
/// Set the sign request handler
///
///
/// This handler will be called whenever the server sends a signature request.
///
///
/// # Arguments
/// * `handler` - Implementation of SignRequestHandler trait
pub fn set_sign_handler<H>(&mut self, handler: H)
@@ -117,6 +132,8 @@ impl SigSocketClient {
self.sign_handler = Some(Box::new(handler));
}
/// Get the current connection state
pub fn state(&self) -> ConnectionState {
self.state
@@ -136,6 +153,109 @@ impl SigSocketClient {
pub fn url(&self) -> &str {
&self.url
}
/// Get the connected public key (if connected)
pub fn connected_public_key(&self) -> Option<&str> {
self.connected_public_key.as_deref()
}
// === Request Management Methods ===
/// Add a pending sign request
///
/// This is typically called when a sign request is received from the server.
/// The request will be stored and can be retrieved later for processing.
///
/// # Arguments
/// * `request` - The sign request to add
/// * `target_public_key` - The public key this request is intended for
pub fn add_pending_request(&mut self, request: SignRequest, target_public_key: String) {
let managed_request = ManagedSignRequest::new(request, target_public_key);
self.pending_requests.insert(managed_request.id().to_string(), managed_request);
}
/// Remove a pending request by ID
///
/// # Arguments
/// * `request_id` - The ID of the request to remove
///
/// # Returns
/// * `Some(request)` - The removed request if it existed
/// * `None` - If no request with that ID was found
pub fn remove_pending_request(&mut self, request_id: &str) -> Option<ManagedSignRequest> {
self.pending_requests.remove(request_id)
}
/// Get a pending request by ID
///
/// # Arguments
/// * `request_id` - The ID of the request to retrieve
///
/// # Returns
/// * `Some(request)` - The request if it exists
/// * `None` - If no request with that ID was found
pub fn get_pending_request(&self, request_id: &str) -> Option<&ManagedSignRequest> {
self.pending_requests.get(request_id)
}
/// Get all pending requests
///
/// # Returns
/// * A reference to the HashMap containing all pending requests
pub fn get_pending_requests(&self) -> &HashMap<String, ManagedSignRequest> {
&self.pending_requests
}
/// Get pending requests filtered by public key
///
/// # Arguments
/// * `public_key` - The public key to filter by (hex-encoded)
///
/// # Returns
/// * A vector of references to requests for the specified public key
pub fn get_requests_for_public_key(&self, public_key: &str) -> Vec<&ManagedSignRequest> {
self.pending_requests
.values()
.filter(|req| req.is_for_public_key(public_key))
.collect()
}
/// Check if a request can be handled for the given public key
///
/// This performs protocol-level validation without cryptographic operations.
///
/// # Arguments
/// * `request` - The sign request to validate
/// * `public_key` - The public key to check against (hex-encoded)
///
/// # Returns
/// * `true` - If the request can be handled for this public key
/// * `false` - If the request cannot be handled
pub fn can_handle_request_for_key(&self, request: &SignRequest, public_key: &str) -> bool {
// Basic protocol validation
if request.id.is_empty() || request.message.is_empty() {
return false;
}
// Check if we can decode the message
if request.message_bytes().is_err() {
return false;
}
// For now, we assume any valid request can be handled for any public key
// More sophisticated validation can be added here
!public_key.is_empty()
}
/// Clear all pending requests
pub fn clear_pending_requests(&mut self) {
self.pending_requests.clear();
}
/// Get the count of pending requests
pub fn pending_request_count(&self) -> usize {
self.pending_requests.len()
}
}
// Platform-specific implementations will be added in separate modules
@@ -176,6 +296,7 @@ impl SigSocketClient {
}
self.state = ConnectionState::Connected;
self.connected_public_key = Some(self.public_key_hex());
Ok(())
}
@@ -190,17 +311,19 @@ impl SigSocketClient {
}
self.inner = None;
self.state = ConnectionState::Disconnected;
self.connected_public_key = None;
self.clear_pending_requests();
Ok(())
}
/// Send a sign response to the server
///
///
/// This is typically called after the user has approved a signature request
/// and the application has generated the signature.
///
///
/// # Arguments
/// * `response` - The sign response containing the signature
///
///
/// # Returns
/// * `Ok(())` - Response sent successfully
/// * `Err(error)` - Failed to send response
@@ -215,6 +338,41 @@ impl SigSocketClient {
Err(SigSocketError::NotConnected)
}
}
/// Send a response for a specific request ID with signature
///
/// This is a convenience method that creates a SignResponse and sends it.
///
/// # Arguments
/// * `request_id` - The ID of the request being responded to
/// * `message` - The original message (base64-encoded)
/// * `signature` - The signature (base64-encoded)
///
/// # Returns
/// * `Ok(())` - Response sent successfully
/// * `Err(error)` - Failed to send response
pub async fn send_response(&self, request_id: &str, message: &str, signature: &str) -> Result<()> {
let response = SignResponse::new(request_id, message, signature);
self.send_sign_response(&response).await
}
/// Send a rejection for a specific request ID
///
/// This sends an error response to indicate the request was rejected.
///
/// # Arguments
/// * `request_id` - The ID of the request being rejected
/// * `reason` - The reason for rejection
///
/// # Returns
/// * `Ok(())` - Rejection sent successfully
/// * `Err(error)` - Failed to send rejection
pub async fn send_rejection(&self, request_id: &str, _reason: &str) -> Result<()> {
// For now, we'll send an empty signature to indicate rejection
// This can be improved with a proper rejection protocol
let response = SignResponse::new(request_id, "", "");
self.send_sign_response(&response).await
}
}
impl Drop for SigSocketClient {
@@ -222,3 +380,5 @@ impl Drop for SigSocketClient {
// Cleanup will be handled by the platform-specific implementations
}
}

View File

@@ -60,10 +60,13 @@ mod native;
mod wasm;
pub use error::{SigSocketError, Result};
pub use protocol::{SignRequest, SignResponse};
pub use protocol::{SignRequest, SignResponse, ManagedSignRequest, RequestStatus};
pub use client::{SigSocketClient, SignRequestHandler, ConnectionState};
// Re-export for convenience
pub mod prelude {
pub use crate::{SigSocketClient, SignRequest, SignResponse, SignRequestHandler, ConnectionState, SigSocketError, Result};
pub use crate::{
SigSocketClient, SignRequest, SignResponse, ManagedSignRequest, RequestStatus,
SignRequestHandler, ConnectionState, SigSocketError, Result
};
}

View File

@@ -82,6 +82,92 @@ impl SignResponse {
}
}
/// Enhanced sign request with additional metadata for request management
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ManagedSignRequest {
/// The original sign request
#[serde(flatten)]
pub request: SignRequest,
/// Timestamp when the request was received (Unix timestamp in milliseconds)
pub timestamp: u64,
/// Target public key for this request (hex-encoded)
pub target_public_key: String,
/// Current status of the request
pub status: RequestStatus,
}
/// Status of a sign request
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum RequestStatus {
/// Request is pending user approval
Pending,
/// Request has been approved and signed
Approved,
/// Request has been rejected by user
Rejected,
/// Request has expired or been cancelled
Cancelled,
}
impl ManagedSignRequest {
/// Create a new managed sign request
pub fn new(request: SignRequest, target_public_key: String) -> Self {
Self {
request,
timestamp: current_timestamp_ms(),
target_public_key,
status: RequestStatus::Pending,
}
}
/// Get the request ID
pub fn id(&self) -> &str {
&self.request.id
}
/// Get the message as bytes (decoded from base64)
pub fn message_bytes(&self) -> Result<Vec<u8>, base64::DecodeError> {
self.request.message_bytes()
}
/// Check if this request is for the given public key
pub fn is_for_public_key(&self, public_key: &str) -> bool {
self.target_public_key == public_key
}
/// Mark the request as approved
pub fn mark_approved(&mut self) {
self.status = RequestStatus::Approved;
}
/// Mark the request as rejected
pub fn mark_rejected(&mut self) {
self.status = RequestStatus::Rejected;
}
/// Check if the request is still pending
pub fn is_pending(&self) -> bool {
matches!(self.status, RequestStatus::Pending)
}
}
/// Get current timestamp in milliseconds
#[cfg(not(target_arch = "wasm32"))]
fn current_timestamp_ms() -> u64 {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_millis() as u64
}
/// Get current timestamp in milliseconds (WASM version)
#[cfg(target_arch = "wasm32")]
fn current_timestamp_ms() -> u64 {
// In WASM, we'll use a simple counter or Date.now() via JS
// For now, return 0 - this can be improved later
0
}
#[cfg(test)]
mod tests {
use super::*;
@@ -138,4 +224,33 @@ mod tests {
let deserialized: SignResponse = serde_json::from_str(&json).unwrap();
assert_eq!(response, deserialized);
}
#[test]
fn test_managed_sign_request() {
let request = SignRequest::new("test-id", "dGVzdCBtZXNzYWdl");
let managed = ManagedSignRequest::new(request.clone(), "test-public-key".to_string());
assert_eq!(managed.id(), "test-id");
assert_eq!(managed.request, request);
assert_eq!(managed.target_public_key, "test-public-key");
assert!(managed.is_pending());
assert!(managed.is_for_public_key("test-public-key"));
assert!(!managed.is_for_public_key("other-key"));
}
#[test]
fn test_managed_request_status_changes() {
let request = SignRequest::new("test-id", "dGVzdCBtZXNzYWdl");
let mut managed = ManagedSignRequest::new(request, "test-public-key".to_string());
assert!(managed.is_pending());
managed.mark_approved();
assert_eq!(managed.status, RequestStatus::Approved);
assert!(!managed.is_pending());
managed.mark_rejected();
assert_eq!(managed.status, RequestStatus::Rejected);
assert!(!managed.is_pending());
}
}

View File

@@ -0,0 +1,92 @@
//! Tests for the enhanced request management functionality
use sigsocket_client::prelude::*;
#[test]
fn test_client_request_management() {
let public_key = hex::decode("02f9308a019258c31049344f85f89d5229b531c845836f99b08601f113bce036f9").unwrap();
let mut client = SigSocketClient::new("ws://localhost:8080/ws", public_key).unwrap();
// Initially no requests
assert_eq!(client.pending_request_count(), 0);
assert!(client.get_pending_requests().is_empty());
// Add a request
let request = SignRequest::new("test-1", "dGVzdCBtZXNzYWdl");
let public_key_hex = "02f9308a019258c31049344f85f89d5229b531c845836f99b08601f113bce036f9";
client.add_pending_request(request.clone(), public_key_hex.to_string());
// Check request was added
assert_eq!(client.pending_request_count(), 1);
assert!(client.get_pending_request("test-1").is_some());
// Check filtering by public key
let filtered = client.get_requests_for_public_key(public_key_hex);
assert_eq!(filtered.len(), 1);
assert_eq!(filtered[0].id(), "test-1");
// Add another request for different public key
let request2 = SignRequest::new("test-2", "dGVzdCBtZXNzYWdlMg==");
let other_public_key = "03f9308a019258c31049344f85f89d5229b531c845836f99b08601f113bce036f9";
client.add_pending_request(request2, other_public_key.to_string());
// Check total count
assert_eq!(client.pending_request_count(), 2);
// Check filtering still works
let filtered = client.get_requests_for_public_key(public_key_hex);
assert_eq!(filtered.len(), 1);
let filtered_other = client.get_requests_for_public_key(other_public_key);
assert_eq!(filtered_other.len(), 1);
// Remove a request
let removed = client.remove_pending_request("test-1");
assert!(removed.is_some());
assert_eq!(removed.unwrap().id(), "test-1");
assert_eq!(client.pending_request_count(), 1);
// Clear all requests
client.clear_pending_requests();
assert_eq!(client.pending_request_count(), 0);
}
#[test]
fn test_client_request_validation() {
let public_key = hex::decode("02f9308a019258c31049344f85f89d5229b531c845836f99b08601f113bce036f9").unwrap();
let client = SigSocketClient::new("ws://localhost:8080/ws", public_key).unwrap();
// Valid request
let valid_request = SignRequest::new("test-1", "dGVzdCBtZXNzYWdl");
assert!(client.can_handle_request_for_key(&valid_request, "some-public-key"));
// Invalid request - empty ID
let invalid_request = SignRequest::new("", "dGVzdCBtZXNzYWdl");
assert!(!client.can_handle_request_for_key(&invalid_request, "some-public-key"));
// Invalid request - empty message
let invalid_request2 = SignRequest::new("test-1", "");
assert!(!client.can_handle_request_for_key(&invalid_request2, "some-public-key"));
// Invalid request - invalid base64
let invalid_request3 = SignRequest::new("test-1", "invalid-base64!");
assert!(!client.can_handle_request_for_key(&invalid_request3, "some-public-key"));
// Invalid public key
assert!(!client.can_handle_request_for_key(&valid_request, ""));
}
#[test]
fn test_client_connection_state() {
let public_key = hex::decode("02f9308a019258c31049344f85f89d5229b531c845836f99b08601f113bce036f9").unwrap();
let client = SigSocketClient::new("ws://localhost:8080/ws", public_key).unwrap();
// Initially disconnected
assert_eq!(client.state(), ConnectionState::Disconnected);
assert!(!client.is_connected());
assert!(client.connected_public_key().is_none());
// Public key should be available
assert_eq!(client.public_key_hex(), "02f9308a019258c31049344f85f89d5229b531c845836f99b08601f113bce036f9");
assert_eq!(client.url(), "ws://localhost:8080/ws");
}