db/acldb/src/server.rs
2025-04-20 09:28:12 +02:00

490 lines
17 KiB
Rust

use actix_web::{web, App, HttpResponse, HttpServer, Responder};
use actix_web::middleware::Logger;
use actix_cors::Cors;
use std::collections::HashMap;
use std::sync::Arc;
use std::future;
use tokio::sync::mpsc;
use tokio::sync::RwLock;
use serde_json::json;
use log::{error};
use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi;
use crate::ACLDB;
use crate::rpc::{RpcInterface, RpcRequest, RpcResponse, AclUpdateParams, AclRemoveParams, AclDelParams, SetParams, DelParams, GetParams, PrefixParams};
use crate::error::Error;
use crate::utils::base64_decode;
use std::collections::VecDeque;
use tokio::task;
use tokio::time::{sleep, Duration};
/// Server configuration: the address the HTTP server binds to.
///
/// The default binds to `127.0.0.1:8080`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ServerConfig {
    /// Host address (hostname or IP) to bind to
    pub host: String,
    /// TCP port number to bind to
    pub port: u16,
}

impl Default for ServerConfig {
    /// Returns the loopback-only default configuration (`127.0.0.1:8080`).
    fn default() -> Self {
        ServerConfig {
            host: "127.0.0.1".to_string(),
            port: 8080,
        }
    }
}
/// Request queue for a circle
struct CircleQueue {
/// Queue of pending requests
queue: VecDeque<(RpcRequest, mpsc::Sender<RpcResponse>)>,
/// Flag to indicate if a worker is currently processing this queue
is_processing: bool,
}
impl CircleQueue {
/// Creates a new circle queue
fn new() -> Self {
CircleQueue {
queue: VecDeque::new(),
is_processing: false,
}
}
/// Adds a request to the queue and starts processing if needed
async fn add_request(
&mut self,
request: RpcRequest,
response_sender: mpsc::Sender<RpcResponse>,
rpc_interface: Arc<RpcInterface>,
acldb_factory: Arc<ACLDBFactory>,
) {
// Add the request to the queue
self.queue.push_back((request.clone(), response_sender));
// If no worker is processing this queue, start one
if !self.is_processing {
self.is_processing = true;
// Clone what we need for the worker
let rpc = Arc::clone(&rpc_interface);
let factory = Arc::clone(&acldb_factory);
let mut queue = self.queue.clone();
// Spawn a worker task
task::spawn(async move {
// Process all requests in the queue
while let Some((req, sender)) = queue.pop_front() {
// Process the request
let response = process_request(&req, &rpc, &factory).await;
// Send the response
if let Err(err) = sender.send(response).await {
error!("Failed to send response: {}", err);
}
// Small delay to prevent CPU hogging
sleep(Duration::from_millis(1)).await;
}
});
}
}
}
/// Factory for creating ACLDB instances.
///
/// Caches one shared `ACLDB` handle per circle ID, so every caller asking for
/// the same circle receives a clone of the same `Arc`.
pub struct ACLDBFactory {
    /// Map of circle IDs to ACLDB instances
    dbs: RwLock<HashMap<String, Arc<RwLock<ACLDB>>>>,
}

impl Default for ACLDBFactory {
    fn default() -> Self {
        Self::new()
    }
}

impl ACLDBFactory {
    /// Creates a new, empty ACLDBFactory.
    pub fn new() -> Self {
        ACLDBFactory {
            dbs: RwLock::new(HashMap::new()),
        }
    }

    /// Gets or creates the ACLDB instance for a circle.
    ///
    /// # Errors
    ///
    /// Returns the error from `ACLDB::new` if a new instance cannot be created.
    /// Nothing is cached in that case, so a later call will try again.
    pub async fn get_or_create(&self, circle_id: &str) -> Result<Arc<RwLock<ACLDB>>, Error> {
        // Fast path: most calls hit an existing instance under the shared read lock.
        {
            let dbs = self.dbs.read().await;
            if let Some(db) = dbs.get(circle_id) {
                return Ok(Arc::clone(db));
            }
        }

        let mut dbs = self.dbs.write().await;
        // Re-check under the write lock. Another task may have created the
        // instance between our read and write locks. Without this check, both
        // tasks would construct an ACLDB, and the second insert would silently
        // replace the first, leaving callers with diverging instances.
        // Trade-off: `ACLDB::new` runs while the write lock is held.
        if let Some(db) = dbs.get(circle_id) {
            return Ok(Arc::clone(db));
        }
        let db = Arc::new(RwLock::new(ACLDB::new(circle_id)?));
        dbs.insert(circle_id.to_string(), Arc::clone(&db));
        Ok(db)
    }
}
/// Server for handling RPC requests.
///
/// All shared state sits behind `Arc`s, so clones of a `Server` share the same
/// per-circle queues and ACLDB instances.
#[derive(Clone)]
pub struct Server {
    /// Server configuration
    config: ServerConfig,
    /// RPC interface
    rpc: Arc<RpcInterface>,
    /// Map of circle IDs to request queues
    queues: Arc<RwLock<HashMap<String, CircleQueue>>>,
    /// Factory for creating ACLDB instances
    acldb_factory: Arc<ACLDBFactory>,
}

impl Server {
    /// Creates a new server with no queues and no open databases.
    pub fn new(config: ServerConfig) -> Self {
        Server {
            config,
            rpc: Arc::new(RpcInterface::new()),
            queues: Arc::new(RwLock::new(HashMap::new())),
            acldb_factory: Arc::new(ACLDBFactory::new()),
        }
    }

    /// Starts the HTTP server and runs until it shuts down.
    ///
    /// Routes: `POST /rpc`, `GET /health`, and Swagger UI under `/swagger-ui/`,
    /// which serves the OpenAPI document at `/api-docs/openapi.json`.
    ///
    /// # Errors
    ///
    /// Returns the I/O error if binding to `host:port` fails or the server
    /// stops with an error.
    pub async fn start(&self) -> std::io::Result<()> {
        let server_data = web::Data::new(self.clone());
        // Build the OpenAPI document once, not once per actix worker.
        let openapi = ApiDoc::openapi();
        let address = format!("{}:{}", self.config.host, self.config.port);

        HttpServer::new(move || {
            App::new()
                .wrap(Logger::default())
                // NOTE(review): fully permissive CORS (any origin/method/header);
                // confirm this is intended outside of development.
                .wrap(
                    Cors::default()
                        .allow_any_origin()
                        .allow_any_method()
                        .allow_any_header()
                        .max_age(3600),
                )
                .app_data(web::Data::clone(&server_data))
                .route("/rpc", web::post().to(handle_rpc))
                .route("/health", web::get().to(health_check))
                .service(
                    SwaggerUi::new("/swagger-ui/{_:.*}")
                        .url("/api-docs/openapi.json", openapi.clone()),
                )
        })
        .bind(address)?
        .run()
        .await
    }

    /// Adds a request to the queue for `circle_id`, creating the queue on
    /// first use, and returns the receiver its response will arrive on.
    ///
    /// The queue-map write lock is held while enqueueing. This makes "create
    /// the queue" and "enqueue the first request" a single atomic step.
    async fn add_to_queue(&self, circle_id: &str, request: RpcRequest) -> mpsc::Receiver<RpcResponse> {
        // Capacity 1 is enough: exactly one response is sent per request.
        let (response_sender, response_receiver) = mpsc::channel(1);

        let mut queues = self.queues.write().await;
        queues
            .entry(circle_id.to_string())
            .or_insert_with(CircleQueue::new)
            .add_request(
                request,
                response_sender,
                Arc::clone(&self.rpc),
                Arc::clone(&self.acldb_factory),
            )
            .await;

        response_receiver
    }
}
/// Extracts the circle ID from an RPC request
fn extract_circle_id(request: &web::Json<RpcRequest>) -> Result<String, Error> {
// Extract from different parameter types based on the method
match request.method.as_str() {
"aclupdate" => {
let params: AclUpdateParams = serde_json::from_value(request.params.clone())?;
Ok(params.circle_id)
}
"aclremove" => {
let params: AclRemoveParams = serde_json::from_value(request.params.clone())?;
Ok(params.circle_id)
}
"acldel" => {
let params: AclDelParams = serde_json::from_value(request.params.clone())?;
Ok(params.circle_id)
}
"set" => {
let params: SetParams = serde_json::from_value(request.params.clone())?;
Ok(params.circle_id)
}
"del" => {
let params: DelParams = serde_json::from_value(request.params.clone())?;
Ok(params.circle_id)
}
"get" => {
let params: GetParams = serde_json::from_value(request.params.clone())?;
Ok(params.circle_id)
}
"prefix" => {
let params: PrefixParams = serde_json::from_value(request.params.clone())?;
Ok(params.circle_id)
}
_ => Err(Error::InvalidRequest(format!("Unknown method: {}", request.method))),
}
}
/// API documentation schema.
///
/// OpenAPI description of this server, generated by `utoipa`. It lists the
/// `health_check` and `handle_rpc` handlers (described by their
/// `#[utoipa::path]` attributes) and the `RpcRequest`/`RpcResponse` schemas,
/// all under the `acldb` tag. `Server::start` serves it at
/// `/api-docs/openapi.json` and renders it with Swagger UI.
#[derive(OpenApi)]
#[openapi(
paths(
health_check,
handle_rpc
),
components(
schemas(RpcRequest, RpcResponse)
),
tags(
(name = "acldb", description = "ACLDB API")
)
)]
struct ApiDoc;
/// Handler for `POST /rpc` (documented for OpenAPI).
///
/// The request is routed to its circle's queue, and the handler waits for the
/// queued request to be processed. Status codes:
/// - 400 if the method is unknown or its parameters do not decode far enough
///   to yield a circle ID;
/// - 200 with the processor's `RpcResponse` otherwise, including when
///   processing itself failed (the failure is reported in `error`);
/// - 500 if the response channel closes without a response, e.g. because the
///   worker stopped.
#[utoipa::path(
    post,
    path = "/rpc",
    request_body = RpcRequest,
    responses(
        (status = 200, description = "RPC request processed successfully", body = RpcResponse),
        (status = 400, description = "Bad request", body = RpcResponse),
        (status = 500, description = "Internal server error", body = RpcResponse)
    ),
    tag = "acldb"
)]
async fn handle_rpc(
    server: web::Data<Server>,
    request: web::Json<RpcRequest>,
) -> impl Responder {
    let circle_id = match extract_circle_id(&request) {
        Ok(id) => id,
        Err(err) => {
            return HttpResponse::BadRequest().json(RpcResponse {
                result: None,
                error: Some(err.to_string()),
            });
        }
    };

    // The request body is owned by this handler, so move it into the queue
    // instead of cloning it.
    let mut response_receiver = server.add_to_queue(&circle_id, request.into_inner()).await;

    match response_receiver.recv().await {
        Some(response) => HttpResponse::Ok().json(response),
        None => HttpResponse::InternalServerError().json(RpcResponse {
            result: None,
            error: Some("Failed to get response".to_string()),
        }),
    }
}
/// Process an RPC request
async fn process_request(
request: &RpcRequest,
rpc_interface: &Arc<RpcInterface>,
acldb_factory: &Arc<ACLDBFactory>
) -> RpcResponse {
match request.method.as_str() {
"aclupdate" => {
match serde_json::from_value::<AclUpdateParams>(request.params.clone()) {
Ok(params) => {
match RpcInterface::parse_acl_right(&params.right) {
Ok(right) => {
match acldb_factory.get_or_create(&params.circle_id).await {
Ok(db) => {
let mut db = db.write().await;
match db.acl_update(&params.caller_pubkey, &params.name, &params.pubkeys, right).await {
Ok(_) => RpcResponse {
result: Some(json!({"success": true})),
error: None,
},
Err(err) => RpcResponse {
result: None,
error: Some(err.to_string()),
},
}
},
Err(err) => RpcResponse {
result: None,
error: Some(err.to_string()),
},
}
},
Err(err) => RpcResponse {
result: None,
error: Some(err.to_string()),
},
}
},
Err(err) => RpcResponse {
result: None,
error: Some(format!("Invalid parameters: {}", err)),
},
}
},
"aclremove" => {
match serde_json::from_value::<AclRemoveParams>(request.params.clone()) {
Ok(params) => {
match acldb_factory.get_or_create(&params.circle_id).await {
Ok(db) => {
let mut db = db.write().await;
match db.acl_remove(&params.caller_pubkey, &params.name, &params.pubkeys).await {
Ok(_) => RpcResponse {
result: Some(json!({"success": true})),
error: None,
},
Err(err) => RpcResponse {
result: None,
error: Some(err.to_string()),
},
}
},
Err(err) => RpcResponse {
result: None,
error: Some(err.to_string()),
},
}
},
Err(err) => RpcResponse {
result: None,
error: Some(format!("Invalid parameters: {}", err)),
},
}
},
"acldel" => {
match serde_json::from_value::<AclDelParams>(request.params.clone()) {
Ok(params) => {
match acldb_factory.get_or_create(&params.circle_id).await {
Ok(db) => {
let mut db = db.write().await;
match db.acl_del(&params.caller_pubkey, &params.name).await {
Ok(_) => RpcResponse {
result: Some(json!({"success": true})),
error: None,
},
Err(err) => RpcResponse {
result: None,
error: Some(err.to_string()),
},
}
},
Err(err) => RpcResponse {
result: None,
error: Some(err.to_string()),
},
}
},
Err(err) => RpcResponse {
result: None,
error: Some(format!("Invalid parameters: {}", err)),
},
}
},
"set" => {
match serde_json::from_value::<SetParams>(request.params.clone()) {
Ok(params) => {
match acldb_factory.get_or_create(&params.circle_id).await {
Ok(db) => {
let mut db = db.write().await;
let topic = db.topic(&params.topic);
match base64_decode(&params.value) {
Ok(value) => {
let acl_id = params.acl_id.unwrap_or(0);
let result = if let Some(key) = params.key {
let topic = topic.write().await;
topic.set_with_acl(&key, &value, acl_id).await
} else if let Some(id) = params.id {
let topic = topic.write().await;
topic.set_with_acl(&id.to_string(), &value, acl_id).await
} else {
// Return a future that resolves to an error for consistency
future::ready(Err(Error::InvalidRequest("Either key or id must be provided".to_string()))).await
};
match result {
Ok(id) => RpcResponse {
result: Some(json!({"id": id})),
error: None,
},
Err(err) => RpcResponse {
result: None,
error: Some(err.to_string()),
},
}
},
Err(err) => RpcResponse {
result: None,
error: Some(err.to_string()),
},
}
},
Err(err) => RpcResponse {
result: None,
error: Some(err.to_string()),
},
}
},
Err(err) => RpcResponse {
result: None,
error: Some(format!("Invalid parameters: {}", err)),
},
}
},
_ => RpcResponse {
result: None,
error: Some(format!("Unknown method: {}", request.method)),
},
}
}
/// Liveness probe for `GET /health` (documented for OpenAPI).
///
/// Always answers 200 with the JSON body `{"status": "ok"}`.
#[utoipa::path(
    get,
    path = "/health",
    responses(
        (status = 200, description = "Server is healthy", body = String)
    ),
    tag = "acldb"
)]
async fn health_check() -> impl Responder {
    let body = json!({"status": "ok"});
    HttpResponse::Ok().json(body)
}
#[cfg(test)]
mod tests {
// Unit tests for this module (compiled only under `cargo test`).
// TODO: cover request routing in `extract_circle_id` and per-circle request ordering.
}