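//! HTTP server for ACLDB: a JSON-RPC endpoint (`POST /rpc`), a health check (`GET /health`),
//! and Swagger UI backed by a generated OpenAPI document.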
use actix_web::{web, App, HttpResponse, HttpServer, Responder};
|
|
use actix_web::middleware::Logger;
|
|
use actix_cors::Cors;
|
|
use std::collections::HashMap;
|
|
use std::sync::Arc;
|
|
use std::future;
|
|
use tokio::sync::mpsc;
|
|
use tokio::sync::RwLock;
|
|
use serde_json::json;
|
|
use log::{error};
|
|
use utoipa::OpenApi;
|
|
use utoipa_swagger_ui::SwaggerUi;
|
|
|
|
use crate::ACLDB;
|
|
use crate::rpc::{RpcInterface, RpcRequest, RpcResponse, AclUpdateParams, AclRemoveParams, AclDelParams, SetParams, DelParams, GetParams, PrefixParams};
|
|
use crate::error::Error;
|
|
use crate::utils::base64_decode;
|
|
use std::collections::VecDeque;
|
|
use tokio::task;
|
|
use tokio::time::{sleep, Duration};
|
|
|
|
/// Listen-address configuration for [`Server`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ServerConfig {
    /// Host name or IP address to bind to.
    pub host: String,
    /// TCP port to bind to.
    pub port: u16,
}
impl Default for ServerConfig {
|
|
fn default() -> Self {
|
|
ServerConfig {
|
|
host: "127.0.0.1".to_string(),
|
|
port: 8080,
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Request queue for a circle
|
|
struct CircleQueue {
|
|
/// Queue of pending requests
|
|
queue: VecDeque<(RpcRequest, mpsc::Sender<RpcResponse>)>,
|
|
/// Flag to indicate if a worker is currently processing this queue
|
|
is_processing: bool,
|
|
}
|
|
|
|
impl CircleQueue {
|
|
/// Creates a new circle queue
|
|
fn new() -> Self {
|
|
CircleQueue {
|
|
queue: VecDeque::new(),
|
|
is_processing: false,
|
|
}
|
|
}
|
|
|
|
/// Adds a request to the queue and starts processing if needed
|
|
async fn add_request(
|
|
&mut self,
|
|
request: RpcRequest,
|
|
response_sender: mpsc::Sender<RpcResponse>,
|
|
rpc_interface: Arc<RpcInterface>,
|
|
acldb_factory: Arc<ACLDBFactory>,
|
|
) {
|
|
// Add the request to the queue
|
|
self.queue.push_back((request.clone(), response_sender));
|
|
|
|
// If no worker is processing this queue, start one
|
|
if !self.is_processing {
|
|
self.is_processing = true;
|
|
|
|
// Clone what we need for the worker
|
|
let rpc = Arc::clone(&rpc_interface);
|
|
let factory = Arc::clone(&acldb_factory);
|
|
let mut queue = self.queue.clone();
|
|
|
|
// Spawn a worker task
|
|
task::spawn(async move {
|
|
// Process all requests in the queue
|
|
while let Some((req, sender)) = queue.pop_front() {
|
|
// Process the request
|
|
let response = process_request(&req, &rpc, &factory).await;
|
|
|
|
// Send the response
|
|
if let Err(err) = sender.send(response).await {
|
|
error!("Failed to send response: {}", err);
|
|
}
|
|
|
|
// Small delay to prevent CPU hogging
|
|
sleep(Duration::from_millis(1)).await;
|
|
}
|
|
});
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Factory for creating ACLDB instances
|
|
pub struct ACLDBFactory {
|
|
/// Map of circle IDs to ACLDB instances
|
|
dbs: RwLock<HashMap<String, Arc<RwLock<ACLDB>>>>,
|
|
}
|
|
|
|
impl ACLDBFactory {
|
|
/// Creates a new ACLDBFactory
|
|
pub fn new() -> Self {
|
|
ACLDBFactory {
|
|
dbs: RwLock::new(HashMap::new()),
|
|
}
|
|
}
|
|
|
|
/// Gets or creates an ACLDB instance for a circle
|
|
pub async fn get_or_create(&self, circle_id: &str) -> Result<Arc<RwLock<ACLDB>>, Error> {
|
|
// Try to get an existing instance
|
|
let dbs = self.dbs.read().await;
|
|
if let Some(db) = dbs.get(circle_id) {
|
|
return Ok(Arc::clone(db));
|
|
}
|
|
drop(dbs); // Release the read lock
|
|
|
|
// Create a new instance
|
|
let db = Arc::new(RwLock::new(ACLDB::new(circle_id)?));
|
|
|
|
// Store it in the map
|
|
let mut dbs = self.dbs.write().await;
|
|
dbs.insert(circle_id.to_string(), Arc::clone(&db));
|
|
|
|
Ok(db)
|
|
}
|
|
}
|
|
|
|
/// Server for handling RPC requests
|
|
#[derive(Clone)]
|
|
pub struct Server {
|
|
/// Server configuration
|
|
config: ServerConfig,
|
|
/// RPC interface
|
|
rpc: Arc<RpcInterface>,
|
|
/// Map of circle IDs to request queues
|
|
queues: Arc<RwLock<HashMap<String, CircleQueue>>>,
|
|
/// Factory for creating ACLDB instances
|
|
acldb_factory: Arc<ACLDBFactory>,
|
|
}
|
|
|
|
impl Server {
|
|
/// Creates a new server
|
|
pub fn new(config: ServerConfig) -> Self {
|
|
let rpc = Arc::new(RpcInterface::new());
|
|
let queues = Arc::new(RwLock::new(HashMap::new()));
|
|
let acldb_factory = Arc::new(ACLDBFactory::new());
|
|
|
|
Server {
|
|
config,
|
|
rpc,
|
|
queues,
|
|
acldb_factory,
|
|
}
|
|
}
|
|
|
|
/// Starts the server
|
|
pub async fn start(&self) -> std::io::Result<()> {
|
|
let server_data = web::Data::new(self.clone());
|
|
|
|
// Start the HTTP server
|
|
HttpServer::new(move || {
|
|
App::new()
|
|
.wrap(Logger::default())
|
|
.wrap(
|
|
Cors::default()
|
|
.allow_any_origin()
|
|
.allow_any_method()
|
|
.allow_any_header()
|
|
.max_age(3600)
|
|
)
|
|
.app_data(web::Data::clone(&server_data))
|
|
.route("/rpc", web::post().to(handle_rpc))
|
|
.route("/health", web::get().to(health_check))
|
|
.service(
|
|
SwaggerUi::new("/swagger-ui/{_:.*}")
|
|
.url("/api-docs/openapi.json", ApiDoc::openapi())
|
|
)
|
|
})
|
|
.bind(format!("{}:{}", self.config.host, self.config.port))?
|
|
.run()
|
|
.await
|
|
}
|
|
|
|
/// Registers RPC handlers
|
|
fn register_handlers(&self) {
|
|
// Nothing to do here - handlers are now processed dynamically
|
|
}
|
|
|
|
/// Adds a request to the queue for a circle
|
|
async fn add_to_queue(&self, circle_id: &str, request: RpcRequest) -> mpsc::Receiver<RpcResponse> {
|
|
let (response_sender, response_receiver) = mpsc::channel(1);
|
|
|
|
// Get or create the queue for this circle
|
|
let mut queues = self.queues.write().await;
|
|
|
|
if !queues.contains_key(circle_id) {
|
|
queues.insert(circle_id.to_string(), CircleQueue::new());
|
|
}
|
|
|
|
// Get a mutable reference to the queue
|
|
if let Some(queue) = queues.get_mut(circle_id) {
|
|
// Add the request to the queue
|
|
queue.add_request(
|
|
request,
|
|
response_sender,
|
|
Arc::clone(&self.rpc),
|
|
Arc::clone(&self.acldb_factory)
|
|
).await;
|
|
}
|
|
|
|
response_receiver
|
|
}
|
|
}
|
|
|
|
|
|
|
|
/// Extracts the circle ID from an RPC request
|
|
fn extract_circle_id(request: &web::Json<RpcRequest>) -> Result<String, Error> {
|
|
// Extract from different parameter types based on the method
|
|
match request.method.as_str() {
|
|
"aclupdate" => {
|
|
let params: AclUpdateParams = serde_json::from_value(request.params.clone())?;
|
|
Ok(params.circle_id)
|
|
}
|
|
"aclremove" => {
|
|
let params: AclRemoveParams = serde_json::from_value(request.params.clone())?;
|
|
Ok(params.circle_id)
|
|
}
|
|
"acldel" => {
|
|
let params: AclDelParams = serde_json::from_value(request.params.clone())?;
|
|
Ok(params.circle_id)
|
|
}
|
|
"set" => {
|
|
let params: SetParams = serde_json::from_value(request.params.clone())?;
|
|
Ok(params.circle_id)
|
|
}
|
|
"del" => {
|
|
let params: DelParams = serde_json::from_value(request.params.clone())?;
|
|
Ok(params.circle_id)
|
|
}
|
|
"get" => {
|
|
let params: GetParams = serde_json::from_value(request.params.clone())?;
|
|
Ok(params.circle_id)
|
|
}
|
|
"prefix" => {
|
|
let params: PrefixParams = serde_json::from_value(request.params.clone())?;
|
|
Ok(params.circle_id)
|
|
}
|
|
_ => Err(Error::InvalidRequest(format!("Unknown method: {}", request.method))),
|
|
}
|
|
}
|
|
|
|
/// API documentation schema
|
|
#[derive(OpenApi)]
|
|
#[openapi(
|
|
paths(
|
|
health_check,
|
|
handle_rpc
|
|
),
|
|
components(
|
|
schemas(RpcRequest, RpcResponse)
|
|
),
|
|
tags(
|
|
(name = "acldb", description = "ACLDB API")
|
|
)
|
|
)]
|
|
struct ApiDoc;
|
|
|
|
/// Handler for RPC requests with OpenAPI documentation
|
|
#[utoipa::path(
|
|
post,
|
|
path = "/rpc",
|
|
request_body = RpcRequest,
|
|
responses(
|
|
(status = 200, description = "RPC request processed successfully", body = RpcResponse),
|
|
(status = 400, description = "Bad request", body = RpcResponse),
|
|
(status = 500, description = "Internal server error", body = RpcResponse)
|
|
),
|
|
tag = "acldb"
|
|
)]
|
|
async fn handle_rpc(
|
|
server: web::Data<Server>,
|
|
request: web::Json<RpcRequest>,
|
|
) -> impl Responder {
|
|
// Extract the circle ID from the request
|
|
let circle_id = match extract_circle_id(&request) {
|
|
Ok(id) => id,
|
|
Err(err) => {
|
|
return HttpResponse::BadRequest().json(RpcResponse {
|
|
result: None,
|
|
error: Some(err.to_string()),
|
|
});
|
|
}
|
|
};
|
|
|
|
// Add the request to the queue for this circle
|
|
let mut response_receiver = server.add_to_queue(&circle_id, request.0.clone()).await;
|
|
|
|
// Wait for the response
|
|
match response_receiver.recv().await {
|
|
Some(response) => HttpResponse::Ok().json(response),
|
|
None => HttpResponse::InternalServerError().json(RpcResponse {
|
|
result: None,
|
|
error: Some("Failed to get response".to_string()),
|
|
}),
|
|
}
|
|
}
|
|
|
|
/// Process an RPC request
|
|
async fn process_request(
|
|
request: &RpcRequest,
|
|
rpc_interface: &Arc<RpcInterface>,
|
|
acldb_factory: &Arc<ACLDBFactory>
|
|
) -> RpcResponse {
|
|
match request.method.as_str() {
|
|
"aclupdate" => {
|
|
match serde_json::from_value::<AclUpdateParams>(request.params.clone()) {
|
|
Ok(params) => {
|
|
match RpcInterface::parse_acl_right(¶ms.right) {
|
|
Ok(right) => {
|
|
match acldb_factory.get_or_create(¶ms.circle_id).await {
|
|
Ok(db) => {
|
|
let mut db = db.write().await;
|
|
match db.acl_update(¶ms.caller_pubkey, ¶ms.name, ¶ms.pubkeys, right).await {
|
|
Ok(_) => RpcResponse {
|
|
result: Some(json!({"success": true})),
|
|
error: None,
|
|
},
|
|
Err(err) => RpcResponse {
|
|
result: None,
|
|
error: Some(err.to_string()),
|
|
},
|
|
}
|
|
},
|
|
Err(err) => RpcResponse {
|
|
result: None,
|
|
error: Some(err.to_string()),
|
|
},
|
|
}
|
|
},
|
|
Err(err) => RpcResponse {
|
|
result: None,
|
|
error: Some(err.to_string()),
|
|
},
|
|
}
|
|
},
|
|
Err(err) => RpcResponse {
|
|
result: None,
|
|
error: Some(format!("Invalid parameters: {}", err)),
|
|
},
|
|
}
|
|
},
|
|
"aclremove" => {
|
|
match serde_json::from_value::<AclRemoveParams>(request.params.clone()) {
|
|
Ok(params) => {
|
|
match acldb_factory.get_or_create(¶ms.circle_id).await {
|
|
Ok(db) => {
|
|
let mut db = db.write().await;
|
|
match db.acl_remove(¶ms.caller_pubkey, ¶ms.name, ¶ms.pubkeys).await {
|
|
Ok(_) => RpcResponse {
|
|
result: Some(json!({"success": true})),
|
|
error: None,
|
|
},
|
|
Err(err) => RpcResponse {
|
|
result: None,
|
|
error: Some(err.to_string()),
|
|
},
|
|
}
|
|
},
|
|
Err(err) => RpcResponse {
|
|
result: None,
|
|
error: Some(err.to_string()),
|
|
},
|
|
}
|
|
},
|
|
Err(err) => RpcResponse {
|
|
result: None,
|
|
error: Some(format!("Invalid parameters: {}", err)),
|
|
},
|
|
}
|
|
},
|
|
"acldel" => {
|
|
match serde_json::from_value::<AclDelParams>(request.params.clone()) {
|
|
Ok(params) => {
|
|
match acldb_factory.get_or_create(¶ms.circle_id).await {
|
|
Ok(db) => {
|
|
let mut db = db.write().await;
|
|
match db.acl_del(¶ms.caller_pubkey, ¶ms.name).await {
|
|
Ok(_) => RpcResponse {
|
|
result: Some(json!({"success": true})),
|
|
error: None,
|
|
},
|
|
Err(err) => RpcResponse {
|
|
result: None,
|
|
error: Some(err.to_string()),
|
|
},
|
|
}
|
|
},
|
|
Err(err) => RpcResponse {
|
|
result: None,
|
|
error: Some(err.to_string()),
|
|
},
|
|
}
|
|
},
|
|
Err(err) => RpcResponse {
|
|
result: None,
|
|
error: Some(format!("Invalid parameters: {}", err)),
|
|
},
|
|
}
|
|
},
|
|
"set" => {
|
|
match serde_json::from_value::<SetParams>(request.params.clone()) {
|
|
Ok(params) => {
|
|
match acldb_factory.get_or_create(¶ms.circle_id).await {
|
|
Ok(db) => {
|
|
let mut db = db.write().await;
|
|
let topic = db.topic(¶ms.topic);
|
|
|
|
match base64_decode(¶ms.value) {
|
|
Ok(value) => {
|
|
let acl_id = params.acl_id.unwrap_or(0);
|
|
|
|
let result = if let Some(key) = params.key {
|
|
let topic = topic.write().await;
|
|
topic.set_with_acl(&key, &value, acl_id).await
|
|
} else if let Some(id) = params.id {
|
|
let topic = topic.write().await;
|
|
topic.set_with_acl(&id.to_string(), &value, acl_id).await
|
|
} else {
|
|
// Return a future that resolves to an error for consistency
|
|
future::ready(Err(Error::InvalidRequest("Either key or id must be provided".to_string()))).await
|
|
};
|
|
|
|
match result {
|
|
Ok(id) => RpcResponse {
|
|
result: Some(json!({"id": id})),
|
|
error: None,
|
|
},
|
|
Err(err) => RpcResponse {
|
|
result: None,
|
|
error: Some(err.to_string()),
|
|
},
|
|
}
|
|
},
|
|
Err(err) => RpcResponse {
|
|
result: None,
|
|
error: Some(err.to_string()),
|
|
},
|
|
}
|
|
},
|
|
Err(err) => RpcResponse {
|
|
result: None,
|
|
error: Some(err.to_string()),
|
|
},
|
|
}
|
|
},
|
|
Err(err) => RpcResponse {
|
|
result: None,
|
|
error: Some(format!("Invalid parameters: {}", err)),
|
|
},
|
|
}
|
|
},
|
|
_ => RpcResponse {
|
|
result: None,
|
|
error: Some(format!("Unknown method: {}", request.method)),
|
|
},
|
|
}
|
|
}
|
|
|
|
/// Handler for health check with OpenAPI documentation
|
|
#[utoipa::path(
|
|
get,
|
|
path = "/health",
|
|
responses(
|
|
(status = 200, description = "Server is healthy", body = String)
|
|
),
|
|
tag = "acldb"
|
|
)]
|
|
async fn health_check() -> impl Responder {
|
|
HttpResponse::Ok().json(json!({"status": "ok"}))
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration listens on the loopback interface, port 8080.
    #[test]
    fn default_config_listens_on_loopback_8080() {
        let config = ServerConfig::default();
        assert_eq!(config.host, "127.0.0.1");
        assert_eq!(config.port, 8080);
    }

    /// Cloning a configuration yields an independent copy.
    #[test]
    fn cloned_config_is_independent() {
        let original = ServerConfig::default();
        let mut copy = original.clone();
        copy.port = 9090;
        assert_eq!(original.port, 8080);
        assert_eq!(copy.host, original.host);
    }
}