6 Commits

18 changed files with 2230 additions and 84 deletions

926
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -24,6 +24,7 @@ age = "0.10"
secrecy = "0.8"
ed25519-dalek = "2"
base64 = "0.22"
jsonrpsee = { version = "0.26.0", features = ["http-client", "ws-client", "server", "macros"] }
[dev-dependencies]
redis = { version = "0.24", features = ["aio", "tokio-comp"] }

361
src/admin_meta.rs Normal file
View File

@@ -0,0 +1,361 @@
use std::path::PathBuf;
use std::sync::{Arc, OnceLock, Mutex, RwLock};
use std::collections::HashMap;
use crate::error::DBError;
use crate::options;
use crate::rpc::Permissions;
use crate::storage::Storage;
use crate::storage_sled::SledStorage;
use crate::storage_trait::StorageBackend;
// Key builders
fn k_admin_next_id() -> &'static str {
"admin:next_id"
}
fn k_admin_dbs() -> &'static str {
"admin:dbs"
}
fn k_meta_db(id: u64) -> String {
format!("meta:db:{}", id)
}
fn k_meta_db_keys(id: u64) -> String {
format!("meta:db:{}:keys", id)
}
fn k_meta_db_enc(id: u64) -> String {
format!("meta:db:{}:enc", id)
}
// Global cache of admin DB 0 handles per base_dir to avoid sled/reDB file-lock contention
// and to correctly isolate different test instances with distinct directories.
static ADMIN_STORAGES: OnceLock<RwLock<HashMap<String, Arc<dyn StorageBackend>>>> = OnceLock::new();
// Global registry for data DB storages to avoid double-open across process.
static DATA_STORAGES: OnceLock<RwLock<HashMap<u64, Arc<dyn StorageBackend>>>> = OnceLock::new();
static DATA_INIT_LOCK: Mutex<()> = Mutex::new(());
fn init_admin_storage(
base_dir: &str,
backend: options::BackendType,
admin_secret: &str,
) -> Result<Arc<dyn StorageBackend>, DBError> {
let db_file = PathBuf::from(base_dir).join("0.db");
if let Some(parent_dir) = db_file.parent() {
std::fs::create_dir_all(parent_dir).map_err(|e| {
DBError(format!("Failed to create directory {}: {}", parent_dir.display(), e))
})?;
}
let storage: Arc<dyn StorageBackend> = match backend {
options::BackendType::Redb => Arc::new(Storage::new(&db_file, true, Some(admin_secret))?),
options::BackendType::Sled => Arc::new(SledStorage::new(&db_file, true, Some(admin_secret))?),
};
Ok(storage)
}
// Get or initialize a cached handle to admin DB 0 per base_dir (thread-safe, no double-open race)
pub fn open_admin_storage(
base_dir: &str,
backend: options::BackendType,
admin_secret: &str,
) -> Result<Arc<dyn StorageBackend>, DBError> {
let map = ADMIN_STORAGES.get_or_init(|| RwLock::new(HashMap::new()));
// Fast path
if let Some(st) = map.read().unwrap().get(base_dir) {
return Ok(st.clone());
}
// Slow path with write lock
{
let mut w = map.write().unwrap();
if let Some(st) = w.get(base_dir) {
return Ok(st.clone());
}
let st = init_admin_storage(base_dir, backend, admin_secret)?;
w.insert(base_dir.to_string(), st.clone());
return Ok(st);
}
}
// Ensure admin structures exist in encrypted DB 0
pub fn ensure_bootstrap(
base_dir: &str,
backend: options::BackendType,
admin_secret: &str,
) -> Result<(), DBError> {
let admin = open_admin_storage(base_dir, backend, admin_secret)?;
// Initialize next id if missing
if !admin.exists(k_admin_next_id())? {
admin.set(k_admin_next_id().to_string(), "1".to_string())?;
}
// admin:dbs is a hash; it's fine if it doesn't exist (hlen -> 0)
Ok(())
}
// Get or initialize a shared handle to a data DB (> 0), avoiding double-open across subsystems
pub fn open_data_storage(
base_dir: &str,
backend: options::BackendType,
admin_secret: &str,
id: u64,
) -> Result<Arc<dyn StorageBackend>, DBError> {
if id == 0 {
return open_admin_storage(base_dir, backend, admin_secret);
}
// Validate existence in admin metadata
if !db_exists(base_dir, backend.clone(), admin_secret, id)? {
return Err(DBError(format!(
"Cannot open database instance {}, as that database instance does not exist.",
id
)));
}
let map = DATA_STORAGES.get_or_init(|| RwLock::new(HashMap::new()));
// Fast path
if let Some(st) = map.read().unwrap().get(&id) {
return Ok(st.clone());
}
// Slow path with init lock
let _guard = DATA_INIT_LOCK.lock().unwrap();
if let Some(st) = map.read().unwrap().get(&id) {
return Ok(st.clone());
}
// Determine per-db encryption
let enc = get_enc_key(base_dir, backend.clone(), admin_secret, id)?;
let should_encrypt = enc.is_some();
// Build database file path and ensure parent dir exists
let db_file = PathBuf::from(base_dir).join(format!("{}.db", id));
if let Some(parent_dir) = db_file.parent() {
std::fs::create_dir_all(parent_dir).map_err(|e| {
DBError(format!("Failed to create directory {}: {}", parent_dir.display(), e))
})?;
}
// Open storage
let storage: Arc<dyn StorageBackend> = match backend {
options::BackendType::Redb => Arc::new(Storage::new(&db_file, should_encrypt, enc.as_deref())?),
options::BackendType::Sled => Arc::new(SledStorage::new(&db_file, should_encrypt, enc.as_deref())?),
};
// Publish to registry
map.write().unwrap().insert(id, storage.clone());
Ok(storage)
}
// Allocate the next DB id and persist new pointer
pub fn allocate_next_id(
base_dir: &str,
backend: options::BackendType,
admin_secret: &str,
) -> Result<u64, DBError> {
let admin = open_admin_storage(base_dir, backend, admin_secret)?;
let cur = admin
.get(k_admin_next_id())?
.unwrap_or_else(|| "1".to_string());
let id: u64 = cur.parse().unwrap_or(1);
let next = id.checked_add(1).ok_or_else(|| DBError("next_id overflow".into()))?;
admin.set(k_admin_next_id().to_string(), next.to_string())?;
// Register into admin:dbs set/hash
let _ = admin.hset(k_admin_dbs(), vec![(id.to_string(), "1".to_string())])?;
// Default meta for the new db: public true
let meta_key = k_meta_db(id);
let _ = admin.hset(&meta_key, vec![("public".to_string(), "true".to_string())])?;
Ok(id)
}
// Check existence of a db id in admin:dbs
pub fn db_exists(
base_dir: &str,
backend: options::BackendType,
admin_secret: &str,
id: u64,
) -> Result<bool, DBError> {
let admin = open_admin_storage(base_dir, backend, admin_secret)?;
Ok(admin.hexists(k_admin_dbs(), &id.to_string())?)
}
// Get per-db encryption key, if any
pub fn get_enc_key(
base_dir: &str,
backend: options::BackendType,
admin_secret: &str,
id: u64,
) -> Result<Option<String>, DBError> {
let admin = open_admin_storage(base_dir, backend, admin_secret)?;
admin.get(&k_meta_db_enc(id))
}
// Set per-db encryption key (called during create)
pub fn set_enc_key(
base_dir: &str,
backend: options::BackendType,
admin_secret: &str,
id: u64,
key: &str,
) -> Result<(), DBError> {
let admin = open_admin_storage(base_dir, backend, admin_secret)?;
admin.set(k_meta_db_enc(id), key.to_string())
}
// Set database public flag
pub fn set_database_public(
base_dir: &str,
backend: options::BackendType,
admin_secret: &str,
id: u64,
public: bool,
) -> Result<(), DBError> {
let admin = open_admin_storage(base_dir, backend, admin_secret)?;
let mk = k_meta_db(id);
let _ = admin.hset(&mk, vec![("public".to_string(), public.to_string())])?;
Ok(())
}
// Internal: load public flag; default to true when meta missing
fn load_public(
admin: &Arc<dyn StorageBackend>,
id: u64,
) -> Result<bool, DBError> {
let mk = k_meta_db(id);
match admin.hget(&mk, "public")? {
Some(v) => Ok(v == "true"),
None => Ok(true),
}
}
// Add access key for db (value format: "Read:ts" or "ReadWrite:ts")
pub fn add_access_key(
base_dir: &str,
backend: options::BackendType,
admin_secret: &str,
id: u64,
key_plain: &str,
perms: Permissions,
) -> Result<(), DBError> {
let admin = open_admin_storage(base_dir, backend, admin_secret)?;
let hash = crate::rpc::hash_key(key_plain);
let v = match perms {
Permissions::Read => format!("Read:{}", now_secs()),
Permissions::ReadWrite => format!("ReadWrite:{}", now_secs()),
};
let _ = admin.hset(&k_meta_db_keys(id), vec![(hash, v)])?;
Ok(())
}
// Delete access key by hash
pub fn delete_access_key(
base_dir: &str,
backend: options::BackendType,
admin_secret: &str,
id: u64,
key_hash: &str,
) -> Result<bool, DBError> {
let admin = open_admin_storage(base_dir, backend, admin_secret)?;
let n = admin.hdel(&k_meta_db_keys(id), vec![key_hash.to_string()])?;
Ok(n > 0)
}
// List access keys, returning (hash, perms, created_at_secs)
pub fn list_access_keys(
base_dir: &str,
backend: options::BackendType,
admin_secret: &str,
id: u64,
) -> Result<Vec<(String, Permissions, u64)>, DBError> {
let admin = open_admin_storage(base_dir, backend, admin_secret)?;
let pairs = admin.hgetall(&k_meta_db_keys(id))?;
let mut out = Vec::new();
for (hash, val) in pairs {
let (perm, ts) = parse_perm_value(&val);
out.push((hash, perm, ts));
}
Ok(out)
}
// Verify access permission for db id with optional key
// Returns:
// - Ok(Some(Permissions)) when access is allowed
// - Ok(None) when not allowed or db missing (caller can distinguish by calling db_exists)
pub fn verify_access(
base_dir: &str,
backend: options::BackendType,
admin_secret: &str,
id: u64,
key_opt: Option<&str>,
) -> Result<Option<Permissions>, DBError> {
// Admin DB 0: require exact admin_secret
if id == 0 {
if let Some(k) = key_opt {
if k == admin_secret {
return Ok(Some(Permissions::ReadWrite));
}
}
return Ok(None);
}
let admin = open_admin_storage(base_dir, backend, admin_secret)?;
if !admin.hexists(k_admin_dbs(), &id.to_string())? {
return Ok(None);
}
// Public?
if load_public(&admin, id)? {
return Ok(Some(Permissions::ReadWrite));
}
// Private: require key and verify
if let Some(k) = key_opt {
let hash = crate::rpc::hash_key(k);
if let Some(v) = admin.hget(&k_meta_db_keys(id), &hash)? {
let (perm, _ts) = parse_perm_value(&v);
return Ok(Some(perm));
}
}
Ok(None)
}
// Enumerate all db ids
pub fn list_dbs(
base_dir: &str,
backend: options::BackendType,
admin_secret: &str,
) -> Result<Vec<u64>, DBError> {
let admin = open_admin_storage(base_dir, backend, admin_secret)?;
let ids = admin.hkeys(k_admin_dbs())?;
let mut out = Vec::new();
for s in ids {
if let Ok(v) = s.parse() {
out.push(v);
}
}
Ok(out)
}
// Helper: parse permission value "Read:ts" or "ReadWrite:ts"
fn parse_perm_value(v: &str) -> (Permissions, u64) {
let mut parts = v.split(':');
let p = parts.next().unwrap_or("Read");
let ts = parts
.next()
.and_then(|s| s.parse().ok())
.unwrap_or(0u64);
let perm = match p {
"ReadWrite" => Permissions::ReadWrite,
_ => Permissions::Read,
};
(perm, ts)
}
fn now_secs() -> u64 {
use std::time::{SystemTime, UNIX_EPOCH};
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs()
}

View File

@@ -6,7 +6,7 @@ use futures::future::select_all;
pub enum Cmd {
Ping,
Echo(String),
Select(u64), // Changed from u16 to u64
Select(u64, Option<String>), // db_index, optional_key
Get(String),
Set(String, String),
SetPx(String, String, u128),
@@ -98,11 +98,18 @@ impl Cmd {
Ok((
match cmd[0].to_lowercase().as_str() {
"select" => {
if cmd.len() != 2 {
if cmd.len() < 2 || cmd.len() > 4 {
return Err(DBError("wrong number of arguments for SELECT".to_string()));
}
let idx = cmd[1].parse::<u64>().map_err(|_| DBError("ERR DB index is not an integer".to_string()))?;
Cmd::Select(idx)
let key = if cmd.len() == 4 && cmd[2].to_lowercase() == "key" {
Some(cmd[3].clone())
} else if cmd.len() == 2 {
None
} else {
return Err(DBError("ERR syntax error".to_string()));
};
Cmd::Select(idx, key)
}
"echo" => Cmd::Echo(cmd[1].clone()),
"ping" => Cmd::Ping,
@@ -642,7 +649,7 @@ impl Cmd {
}
match self {
Cmd::Select(db) => select_cmd(server, db).await,
Cmd::Select(db, key) => select_cmd(server, db, key).await,
Cmd::Ping => Ok(Protocol::SimpleString("PONG".to_string())),
Cmd::Echo(s) => Ok(Protocol::BulkString(s)),
Cmd::Get(k) => get_cmd(server, &k).await,
@@ -736,7 +743,14 @@ impl Cmd {
pub fn to_protocol(self) -> Protocol {
match self {
Cmd::Select(db) => Protocol::Array(vec![Protocol::BulkString("select".to_string()), Protocol::BulkString(db.to_string())]),
Cmd::Select(db, key) => {
let mut arr = vec![Protocol::BulkString("select".to_string()), Protocol::BulkString(db.to_string())];
if let Some(k) = key {
arr.push(Protocol::BulkString("key".to_string()));
arr.push(Protocol::BulkString(k));
}
Protocol::Array(arr)
}
Cmd::Ping => Protocol::Array(vec![Protocol::BulkString("ping".to_string())]),
Cmd::Echo(s) => Protocol::Array(vec![Protocol::BulkString("echo".to_string()), Protocol::BulkString(s)]),
Cmd::Get(k) => Protocol::Array(vec![Protocol::BulkString("get".to_string()), Protocol::BulkString(k)]),
@@ -753,9 +767,65 @@ async fn flushdb_cmd(server: &mut Server) -> Result<Protocol, DBError> {
}
}
async fn select_cmd(server: &mut Server, db: u64) -> Result<Protocol, DBError> {
// Test if we can access the database (this will create it if needed)
async fn select_cmd(server: &mut Server, db: u64, key: Option<String>) -> Result<Protocol, DBError> {
// Authorization and existence checks via admin DB 0
// DB 0: require KEY admin-secret
if db == 0 {
match key {
Some(k) if k == server.option.admin_secret => {
server.selected_db = 0;
server.current_permissions = Some(crate::rpc::Permissions::ReadWrite);
// Will create encrypted 0.db if missing
match server.current_storage() {
Ok(_) => return Ok(Protocol::SimpleString("OK".to_string())),
Err(e) => return Ok(Protocol::err(&e.0)),
}
}
_ => {
return Ok(Protocol::err("ERR invalid access key"));
}
}
}
// DB > 0: must exist in admin:dbs
let exists = match crate::admin_meta::db_exists(
&server.option.dir,
server.option.backend.clone(),
&server.option.admin_secret,
db,
) {
Ok(b) => b,
Err(e) => return Ok(Protocol::err(&e.0)),
};
if !exists {
return Ok(Protocol::err(&format!(
"Cannot open database instance {}, as that database instance does not exist.",
db
)));
}
// Verify permissions (public => RW; private => use key)
let perms_opt = match crate::admin_meta::verify_access(
&server.option.dir,
server.option.backend.clone(),
&server.option.admin_secret,
db,
key.as_deref(),
) {
Ok(p) => p,
Err(e) => return Ok(Protocol::err(&e.0)),
};
let perms = match perms_opt {
Some(p) => p,
None => return Ok(Protocol::err("ERR invalid access key")),
};
// Set selected database and permissions, then open storage
server.selected_db = db;
server.current_permissions = Some(perms);
match server.current_storage() {
Ok(_) => Ok(Protocol::SimpleString("OK".to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
@@ -1003,6 +1073,9 @@ async fn brpop_cmd(server: &Server, keys: &[String], timeout_secs: f64) -> Resul
}
async fn lpush_cmd(server: &Server, key: &str, elements: &[String]) -> Result<Protocol, DBError> {
if !server.has_write_permission() {
return Ok(Protocol::err("ERR write permission denied"));
}
match server.current_storage()?.lpush(key, elements.to_vec()) {
Ok(len) => {
// Attempt to deliver to any blocked BLPOP waiters
@@ -1134,6 +1207,9 @@ async fn type_cmd(server: &Server, k: &String) -> Result<Protocol, DBError> {
}
async fn del_cmd(server: &Server, k: &str) -> Result<Protocol, DBError> {
if !server.has_write_permission() {
return Ok(Protocol::err("ERR write permission denied"));
}
server.current_storage()?.del(k.to_string())?;
Ok(Protocol::SimpleString("1".to_string()))
}
@@ -1159,6 +1235,9 @@ async fn set_px_cmd(
}
async fn set_cmd(server: &Server, k: &str, v: &str) -> Result<Protocol, DBError> {
if !server.has_write_permission() {
return Ok(Protocol::err("ERR write permission denied"));
}
server.current_storage()?.set(k.to_string(), v.to_string())?;
Ok(Protocol::SimpleString("OK".to_string()))
}
@@ -1273,6 +1352,9 @@ async fn get_cmd(server: &Server, k: &str) -> Result<Protocol, DBError> {
// Hash command implementations
async fn hset_cmd(server: &Server, key: &str, pairs: &[(String, String)]) -> Result<Protocol, DBError> {
if !server.has_write_permission() {
return Ok(Protocol::err("ERR write permission denied"));
}
let new_fields = server.current_storage()?.hset(key, pairs.to_vec())?;
Ok(Protocol::SimpleString(new_fields.to_string()))
}

View File

@@ -4,7 +4,10 @@ pub mod crypto;
pub mod error;
pub mod options;
pub mod protocol;
pub mod rpc;
pub mod rpc_server;
pub mod server;
pub mod storage;
pub mod storage_trait; // Add this
pub mod storage_sled; // Add this
pub mod admin_meta;

View File

@@ -3,6 +3,7 @@
use tokio::net::TcpListener;
use herodb::server;
use herodb::rpc_server;
use clap::Parser;
@@ -22,18 +23,29 @@ struct Args {
#[arg(long)]
debug: bool,
/// Master encryption key for encrypted databases
/// Master encryption key for encrypted databases (deprecated; ignored for data DBs)
#[arg(long)]
encryption_key: Option<String>,
/// Encrypt the database
/// Encrypt the database (deprecated; ignored for data DBs)
#[arg(long)]
encrypt: bool,
/// Enable RPC management server
#[arg(long)]
enable_rpc: bool,
/// RPC server port (default: 8080)
#[arg(long, default_value = "8080")]
rpc_port: u16,
/// Use the sled backend
#[arg(long)]
sled: bool,
/// Admin secret used to encrypt DB 0 and authorize admin access (required)
#[arg(long)]
admin_secret: String,
}
#[tokio::main]
@@ -48,9 +60,19 @@ async fn main() {
.await
.unwrap();
// deprecation warnings for legacy flags
if args.encrypt || args.encryption_key.is_some() {
eprintln!("warning: --encrypt and --encryption-key are deprecated and ignored for data DBs. Admin DB 0 is always encrypted with --admin-secret.");
}
// basic validation for admin secret
if args.admin_secret.trim().is_empty() {
eprintln!("error: --admin-secret must not be empty");
std::process::exit(2);
}
// new DB option
let option = herodb::options::DBOption {
dir: args.dir,
dir: args.dir.clone(),
port,
debug: args.debug,
encryption_key: args.encryption_key,
@@ -60,14 +82,42 @@ async fn main() {
} else {
herodb::options::BackendType::Redb
},
admin_secret: args.admin_secret.clone(),
};
let backend = option.backend.clone();
// Bootstrap admin DB 0 before opening any server storage
if let Err(e) = herodb::admin_meta::ensure_bootstrap(&args.dir, backend.clone(), &args.admin_secret) {
eprintln!("Failed to bootstrap admin DB 0: {}", e.0);
std::process::exit(2);
}
// new server
let server = server::Server::new(option).await;
// Add a small delay to ensure the port is ready
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
// Start RPC server if enabled
let rpc_handle = if args.enable_rpc {
let rpc_addr = format!("127.0.0.1:{}", args.rpc_port).parse().unwrap();
let base_dir = args.dir.clone();
match rpc_server::start_rpc_server(rpc_addr, base_dir, backend, args.admin_secret.clone()).await {
Ok(handle) => {
println!("RPC management server started on port {}", args.rpc_port);
Some(handle)
}
Err(e) => {
eprintln!("Failed to start RPC server: {}", e);
None
}
}
} else {
None
};
// accept new connections
loop {
let stream = listener.accept().await;

View File

@@ -9,7 +9,11 @@ pub struct DBOption {
pub dir: String,
pub port: u16,
pub debug: bool,
// Deprecated for data DBs; retained for backward-compat on CLI parsing
pub encrypt: bool,
// Deprecated for data DBs; retained for backward-compat on CLI parsing
pub encryption_key: Option<String>,
pub backend: BackendType,
// New: required admin secret, used to encrypt DB 0 and authorize admin operations
pub admin_secret: String,
}

572
src/rpc.rs Normal file
View File

@@ -0,0 +1,572 @@
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use jsonrpsee::{core::RpcResult, proc_macros::rpc};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use crate::server::Server;
use crate::options::DBOption;
use crate::admin_meta;
/// Database backend types
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum BackendType {
Redb,
Sled,
// Future: InMemory, Custom(String)
}
/// Database configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatabaseConfig {
pub name: Option<String>,
pub storage_path: Option<String>,
pub max_size: Option<u64>,
pub redis_version: Option<String>,
}
/// Database information returned by metadata queries
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatabaseInfo {
pub id: u64,
pub name: Option<String>,
pub backend: BackendType,
pub encrypted: bool,
pub redis_version: Option<String>,
pub storage_path: Option<String>,
pub size_on_disk: Option<u64>,
pub key_count: Option<u64>,
pub created_at: u64,
pub last_access: Option<u64>,
}
/// Access permissions for database keys
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum Permissions {
Read,
ReadWrite,
}
/// Access key information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccessKey {
pub hash: String,
pub permissions: Permissions,
pub created_at: u64,
}
/// Database metadata containing access keys
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatabaseMeta {
pub public: bool,
pub keys: HashMap<String, AccessKey>,
}
/// Access key information returned by RPC
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccessKeyInfo {
pub hash: String,
pub permissions: Permissions,
pub created_at: u64,
}
/// Hash a plaintext key using SHA-256
pub fn hash_key(key: &str) -> String {
let mut hasher = Sha256::new();
hasher.update(key.as_bytes());
format!("{:x}", hasher.finalize())
}
/// RPC trait for HeroDB management
#[rpc(server, client, namespace = "herodb")]
pub trait Rpc {
/// Create a new database with specified configuration
#[method(name = "createDatabase")]
async fn create_database(
&self,
backend: BackendType,
config: DatabaseConfig,
encryption_key: Option<String>,
) -> RpcResult<u64>;
/// Set encryption for an existing database (write-only key)
#[method(name = "setEncryption")]
async fn set_encryption(&self, db_id: u64, encryption_key: String) -> RpcResult<bool>;
/// List all managed databases
#[method(name = "listDatabases")]
async fn list_databases(&self) -> RpcResult<Vec<DatabaseInfo>>;
/// Get detailed information about a specific database
#[method(name = "getDatabaseInfo")]
async fn get_database_info(&self, db_id: u64) -> RpcResult<DatabaseInfo>;
/// Delete a database
#[method(name = "deleteDatabase")]
async fn delete_database(&self, db_id: u64) -> RpcResult<bool>;
/// Get server statistics
#[method(name = "getServerStats")]
async fn get_server_stats(&self) -> RpcResult<HashMap<String, serde_json::Value>>;
/// Add an access key to a database
#[method(name = "addAccessKey")]
async fn add_access_key(&self, db_id: u64, key: String, permissions: String) -> RpcResult<bool>;
/// Delete an access key from a database
#[method(name = "deleteAccessKey")]
async fn delete_access_key(&self, db_id: u64, key_hash: String) -> RpcResult<bool>;
/// List all access keys for a database
#[method(name = "listAccessKeys")]
async fn list_access_keys(&self, db_id: u64) -> RpcResult<Vec<AccessKeyInfo>>;
/// Set database public/private status
#[method(name = "setDatabasePublic")]
async fn set_database_public(&self, db_id: u64, public: bool) -> RpcResult<bool>;
}
/// RPC Server implementation
pub struct RpcServerImpl {
/// Base directory for database files
base_dir: String,
/// Managed database servers
servers: Arc<RwLock<HashMap<u64, Arc<Server>>>>,
/// Next unencrypted database ID to assign
next_unencrypted_id: Arc<RwLock<u64>>,
/// Next encrypted database ID to assign
next_encrypted_id: Arc<RwLock<u64>>,
/// Default backend type
backend: crate::options::BackendType,
/// Encryption keys for databases
encryption_keys: Arc<RwLock<HashMap<u64, Option<String>>>>,
/// Admin secret used to encrypt DB 0 and authorize admin access
admin_secret: String,
}
impl RpcServerImpl {
/// Create a new RPC server instance
pub fn new(base_dir: String, backend: crate::options::BackendType, admin_secret: String) -> Self {
Self {
base_dir,
servers: Arc::new(RwLock::new(HashMap::new())),
next_unencrypted_id: Arc::new(RwLock::new(0)),
next_encrypted_id: Arc::new(RwLock::new(10)),
backend,
encryption_keys: Arc::new(RwLock::new(HashMap::new())),
admin_secret,
}
}
/// Get or create a server instance for the given database ID
async fn get_or_create_server(&self, db_id: u64) -> Result<Arc<Server>, jsonrpsee::types::ErrorObjectOwned> {
// Check if server already exists
{
let servers = self.servers.read().await;
if let Some(server) = servers.get(&db_id) {
return Ok(server.clone());
}
}
// Validate existence via admin DB 0 (metadata), not filesystem presence
let exists = admin_meta::db_exists(&self.base_dir, self.backend.clone(), &self.admin_secret, db_id)
.map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?;
if !exists {
return Err(jsonrpsee::types::ErrorObjectOwned::owned(
-32000,
format!("Database {} not found", db_id),
None::<()>
));
}
// Create server instance with default options
let db_option = DBOption {
dir: self.base_dir.clone(),
port: 0, // Not used for RPC-managed databases
debug: false,
encryption_key: None,
encrypt: false,
backend: self.backend.clone(),
admin_secret: self.admin_secret.clone(),
};
let mut server = Server::new(db_option).await;
// Set the selected database to the db_id
server.selected_db = db_id;
// Lazily open/create physical storage according to admin meta (per-db encryption)
let _ = server.current_storage();
// Store the server
let mut servers = self.servers.write().await;
servers.insert(db_id, Arc::new(server.clone()));
Ok(Arc::new(server))
}
/// Discover existing database IDs from admin DB 0
async fn discover_databases(&self) -> Vec<u64> {
admin_meta::list_dbs(&self.base_dir, self.backend.clone(), &self.admin_secret)
.unwrap_or_default()
}
/// Get the next available database ID
async fn get_next_db_id(&self, is_encrypted: bool) -> u64 {
if is_encrypted {
let mut id = self.next_encrypted_id.write().await;
let current_id = *id;
*id += 1;
current_id
} else {
let mut id = self.next_unencrypted_id.write().await;
let current_id = *id;
*id += 1;
current_id
}
}
/// Load database metadata from file (static version)
pub async fn load_meta_static(base_dir: &str, db_id: u64) -> Result<DatabaseMeta, jsonrpsee::types::ErrorObjectOwned> {
let meta_path = std::path::PathBuf::from(base_dir).join(format!("{}_meta.json", db_id));
// If meta file doesn't exist, create and persist default
if !meta_path.exists() {
let default_meta = DatabaseMeta {
public: true,
keys: HashMap::new(),
};
// Persist default metadata to disk
Self::save_meta_static(base_dir, db_id, &default_meta).await?;
return Ok(default_meta);
}
// Read file as UTF-8 JSON
let json_str = std::fs::read_to_string(&meta_path)
.map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(
-32000,
format!("Failed to read meta file: {}", e),
None::<()>
))?;
serde_json::from_str(&json_str)
.map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(
-32000,
format!("Failed to parse meta JSON: {}", e),
None::<()>
))
}
/// Load database metadata from file
async fn load_meta(&self, db_id: u64) -> Result<DatabaseMeta, jsonrpsee::types::ErrorObjectOwned> {
let meta_path = std::path::PathBuf::from(&self.base_dir).join(format!("{}_meta.json", db_id));
// If meta file doesn't exist, create and persist default
if !meta_path.exists() {
let default_meta = DatabaseMeta {
public: true,
keys: HashMap::new(),
};
self.save_meta(db_id, &default_meta).await?;
return Ok(default_meta);
}
// Read file as UTF-8 JSON (meta files are always plain JSON)
let json_str = std::fs::read_to_string(&meta_path)
.map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(
-32000,
format!("Failed to read meta file: {}", e),
None::<()>
))?;
serde_json::from_str(&json_str)
.map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(
-32000,
format!("Failed to parse meta JSON: {}", e),
None::<()>
))
}
/// Save database metadata to file (static version)
pub async fn save_meta_static(base_dir: &str, db_id: u64, meta: &DatabaseMeta) -> Result<(), jsonrpsee::types::ErrorObjectOwned> {
let meta_path = std::path::PathBuf::from(base_dir).join(format!("{}_meta.json", db_id));
let json_str = serde_json::to_string_pretty(meta)
.map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(
-32000,
format!("Failed to serialize meta: {}", e),
None::<()>
))?;
std::fs::write(&meta_path, json_str)
.map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(
-32000,
format!("Failed to write meta file: {}", e),
None::<()>
))?;
Ok(())
}
/// Save database metadata to file
async fn save_meta(&self, db_id: u64, meta: &DatabaseMeta) -> Result<(), jsonrpsee::types::ErrorObjectOwned> {
let meta_path = std::path::PathBuf::from(&self.base_dir).join(format!("{}_meta.json", db_id));
let json_str = serde_json::to_string_pretty(meta)
.map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(
-32000,
format!("Failed to serialize meta: {}", e),
None::<()>
))?;
// Meta files are always stored as plain JSON (even when data DB is encrypted)
std::fs::write(&meta_path, json_str)
.map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(
-32000,
format!("Failed to write meta file: {}", e),
None::<()>
))?;
Ok(())
}
/// Build database file path for given server/db_id
fn db_file_path(&self, server: &Server, db_id: u64) -> std::path::PathBuf {
std::path::PathBuf::from(&server.option.dir).join(format!("{}.db", db_id))
}
/// Recursively compute size on disk for the database path
fn compute_size_on_disk(&self, path: &std::path::Path) -> Option<u64> {
fn dir_size(p: &std::path::Path) -> u64 {
if p.is_file() {
std::fs::metadata(p).map(|m| m.len()).unwrap_or(0)
} else if p.is_dir() {
let mut total = 0u64;
if let Ok(read) = std::fs::read_dir(p) {
for entry in read.flatten() {
total += dir_size(&entry.path());
}
}
total
} else {
0
}
}
Some(dir_size(path))
}
/// Extract created and last access times (secs) from a path, with fallbacks
fn get_file_times_secs(path: &std::path::Path) -> (u64, Option<u64>) {
let now = std::time::SystemTime::now();
let created = std::fs::metadata(path)
.and_then(|m| m.created().or_else(|_| m.modified()))
.unwrap_or(now)
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
let last_access = std::fs::metadata(path)
.and_then(|m| m.accessed())
.ok()
.and_then(|t| t.duration_since(std::time::UNIX_EPOCH).ok().map(|d| d.as_secs()));
(created, last_access)
}
/// Compose a DatabaseInfo by probing storage, metadata and filesystem
async fn build_database_info(&self, db_id: u64, server: &Server) -> DatabaseInfo {
// Probe storage to determine encryption state
let storage = server.current_storage().ok();
let encrypted = storage.as_ref().map(|s| s.is_encrypted()).unwrap_or(server.option.encrypt);
// Load meta to get access key count
let meta = Self::load_meta_static(&self.base_dir, db_id).await.unwrap_or(DatabaseMeta {
public: true,
keys: HashMap::new(),
});
let key_count = Some(meta.keys.len() as u64);
// Compute size on disk and timestamps
let db_path = self.db_file_path(server, db_id);
let size_on_disk = self.compute_size_on_disk(&db_path);
let meta_path = std::path::PathBuf::from(&self.base_dir).join(format!("{}_meta.json", db_id));
let (created_at, last_access) = if meta_path.exists() {
Self::get_file_times_secs(&meta_path)
} else {
Self::get_file_times_secs(&db_path)
};
let backend = match server.option.backend {
crate::options::BackendType::Redb => BackendType::Redb,
crate::options::BackendType::Sled => BackendType::Sled,
};
DatabaseInfo {
id: db_id,
name: None,
backend,
encrypted,
redis_version: Some("7.0".to_string()),
storage_path: Some(server.option.dir.clone()),
size_on_disk,
key_count,
created_at,
last_access,
}
}
}
#[jsonrpsee::core::async_trait]
impl RpcServer for RpcServerImpl {
async fn create_database(
&self,
backend: BackendType,
_config: DatabaseConfig,
encryption_key: Option<String>,
) -> RpcResult<u64> {
// Allocate new ID via admin DB 0
let db_id = admin_meta::allocate_next_id(&self.base_dir, self.backend.clone(), &self.admin_secret)
.map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?;
// Persist per-db encryption key in admin DB 0 if provided
if let Some(ref key) = encryption_key {
admin_meta::set_enc_key(&self.base_dir, self.backend.clone(), &self.admin_secret, db_id, key)
.map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?;
}
// Ensure base dir exists
if let Err(e) = std::fs::create_dir_all(&self.base_dir) {
return Err(jsonrpsee::types::ErrorObjectOwned::owned(-32000, format!("Failed to ensure base dir: {}", e), None::<()>));
}
// Create server instance using base_dir and admin secret
let option = DBOption {
dir: self.base_dir.clone(),
port: 0, // Not used for RPC-managed databases
debug: false,
encryption_key: None, // per-db key is stored in admin DB 0
encrypt: false, // encryption decided per-db at open time
backend: match backend {
BackendType::Redb => crate::options::BackendType::Redb,
BackendType::Sled => crate::options::BackendType::Sled,
},
admin_secret: self.admin_secret.clone(),
};
let mut server = Server::new(option).await;
server.selected_db = db_id;
// Initialize storage to create physical <id>.db with proper encryption from admin meta
let _ = server.current_storage();
// Store the server in cache
let mut servers = self.servers.write().await;
servers.insert(db_id, Arc::new(server));
Ok(db_id)
}
async fn set_encryption(&self, db_id: u64, _encryption_key: String) -> RpcResult<bool> {
// For now, return false as encryption can only be set during creation
let _servers = self.servers.read().await;
// TODO: Implement encryption setting for existing databases
Ok(false)
}
async fn list_databases(&self) -> RpcResult<Vec<DatabaseInfo>> {
let db_ids = self.discover_databases().await;
let mut result = Vec::new();
for db_id in db_ids {
if let Ok(server) = self.get_or_create_server(db_id).await {
// Build accurate info from storage/meta/fs
let info = self.build_database_info(db_id, &server).await;
result.push(info);
}
}
Ok(result)
}
async fn get_database_info(&self, db_id: u64) -> RpcResult<DatabaseInfo> {
let server = self.get_or_create_server(db_id).await?;
// Build accurate info from storage/meta/fs
let info = self.build_database_info(db_id, &server).await;
Ok(info)
}
async fn delete_database(&self, db_id: u64) -> RpcResult<bool> {
let mut servers = self.servers.write().await;
if let Some(_server) = servers.remove(&db_id) {
// Clean up database files
let db_path = std::path::PathBuf::from(&self.base_dir).join(format!("{}.db", db_id));
if db_path.exists() {
if db_path.is_dir() {
std::fs::remove_dir_all(&db_path).ok();
} else {
std::fs::remove_file(&db_path).ok();
}
}
Ok(true)
} else {
Ok(false)
}
}
async fn get_server_stats(&self) -> RpcResult<HashMap<String, serde_json::Value>> {
let db_ids = self.discover_databases().await;
let mut stats = HashMap::new();
stats.insert("total_databases".to_string(), serde_json::json!(db_ids.len()));
stats.insert("uptime".to_string(), serde_json::json!(
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs()
));
Ok(stats)
}
async fn add_access_key(&self, db_id: u64, key: String, permissions: String) -> RpcResult<bool> {
let perms = match permissions.to_lowercase().as_str() {
"read" => Permissions::Read,
"readwrite" => Permissions::ReadWrite,
_ => return Err(jsonrpsee::types::ErrorObjectOwned::owned(
-32000,
"Invalid permissions: use 'read' or 'readwrite'",
None::<()>
)),
};
admin_meta::add_access_key(&self.base_dir, self.backend.clone(), &self.admin_secret, db_id, &key, perms)
.map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?;
Ok(true)
}
async fn delete_access_key(&self, db_id: u64, key_hash: String) -> RpcResult<bool> {
let ok = admin_meta::delete_access_key(&self.base_dir, self.backend.clone(), &self.admin_secret, db_id, &key_hash)
.map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?;
Ok(ok)
}
async fn list_access_keys(&self, db_id: u64) -> RpcResult<Vec<AccessKeyInfo>> {
let pairs = admin_meta::list_access_keys(&self.base_dir, self.backend.clone(), &self.admin_secret, db_id)
.map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?;
let keys: Vec<AccessKeyInfo> = pairs.into_iter().map(|(hash, perm, ts)| AccessKeyInfo {
hash,
permissions: perm,
created_at: ts,
}).collect();
Ok(keys)
}
async fn set_database_public(&self, db_id: u64, public: bool) -> RpcResult<bool> {
admin_meta::set_database_public(&self.base_dir, self.backend.clone(), &self.admin_secret, db_id, public)
.map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?;
Ok(true)
}
}

49
src/rpc_server.rs Normal file
View File

@@ -0,0 +1,49 @@
use std::net::SocketAddr;
use jsonrpsee::server::{ServerBuilder, ServerHandle};
use jsonrpsee::RpcModule;
use crate::rpc::{RpcServer, RpcServerImpl};
/// Start the RPC server on the specified address
pub async fn start_rpc_server(addr: SocketAddr, base_dir: String, backend: crate::options::BackendType, admin_secret: String) -> Result<ServerHandle, Box<dyn std::error::Error + Send + Sync>> {
// Create the RPC server implementation
let rpc_impl = RpcServerImpl::new(base_dir, backend, admin_secret);
// Create the RPC module
let mut module = RpcModule::new(());
module.merge(RpcServer::into_rpc(rpc_impl))?;
// Build the server with both HTTP and WebSocket support
let server = ServerBuilder::default()
.build(addr)
.await?;
// Start the server
let handle = server.start(module);
println!("RPC server started on {}", addr);
Ok(handle)
}
#[cfg(test)]
mod tests {
use super::*;
use std::time::Duration;
#[tokio::test]
async fn test_rpc_server_startup() {
let addr = "127.0.0.1:0".parse().unwrap(); // Use port 0 for auto-assignment
let base_dir = "/tmp/test_rpc".to_string();
let backend = crate::options::BackendType::Redb; // Default for test
let handle = start_rpc_server(addr, base_dir, backend, "test-admin".to_string()).await.unwrap();
// Give the server a moment to start
tokio::time::sleep(Duration::from_millis(100)).await;
// Stop the server
handle.stop().unwrap();
handle.stopped().await;
}
}

View File

@@ -11,9 +11,8 @@ use crate::cmd::Cmd;
use crate::error::DBError;
use crate::options;
use crate::protocol::Protocol;
use crate::storage::Storage;
use crate::storage_sled::SledStorage;
use crate::storage_trait::StorageBackend;
use crate::admin_meta;
#[derive(Clone)]
pub struct Server {
@@ -22,6 +21,7 @@ pub struct Server {
pub client_name: Option<String>,
pub selected_db: u64, // Changed from usize to u64
pub queued_cmd: Option<Vec<(Cmd, Protocol)>>,
pub current_permissions: Option<crate::rpc::Permissions>,
// BLPOP waiter registry: per (db_index, key) FIFO of waiters
pub list_waiters: Arc<Mutex<HashMap<u64, HashMap<String, Vec<Waiter>>>>>,
@@ -48,6 +48,7 @@ impl Server {
client_name: None,
selected_db: 0,
queued_cmd: None,
current_permissions: None,
list_waiters: Arc::new(Mutex::new(HashMap::new())),
waiter_seq: Arc::new(AtomicU64::new(1)),
@@ -61,44 +62,37 @@ impl Server {
return Ok(storage.clone());
}
// Create new database file
let db_file_path = std::path::PathBuf::from(self.option.dir.clone())
.join(format!("{}.db", self.selected_db));
// Ensure the directory exists before creating the database file
if let Some(parent_dir) = db_file_path.parent() {
std::fs::create_dir_all(parent_dir).map_err(|e| {
DBError(format!("Failed to create directory {}: {}", parent_dir.display(), e))
})?;
}
println!("Creating new db file: {}", db_file_path.display());
let storage: Arc<dyn StorageBackend> = match self.option.backend {
options::BackendType::Redb => {
Arc::new(Storage::new(
db_file_path,
self.should_encrypt_db(self.selected_db),
self.option.encryption_key.as_deref()
)?)
}
options::BackendType::Sled => {
Arc::new(SledStorage::new(
db_file_path,
self.should_encrypt_db(self.selected_db),
self.option.encryption_key.as_deref()
)?)
}
// Use process-wide shared handles to avoid sled/reDB double-open lock contention.
let storage = if self.selected_db == 0 {
// Admin DB 0: always via singleton
admin_meta::open_admin_storage(
&self.option.dir,
self.option.backend.clone(),
&self.option.admin_secret,
)?
} else {
// Data DBs: via global registry keyed by id
admin_meta::open_data_storage(
&self.option.dir,
self.option.backend.clone(),
&self.option.admin_secret,
self.selected_db,
)?
};
cache.insert(self.selected_db, storage.clone());
Ok(storage)
}
fn should_encrypt_db(&self, db_index: u64) -> bool {
// DB 0-9 are non-encrypted, DB 10+ are encrypted
self.option.encrypt && db_index >= 10
/// Check if current permissions allow read operations
pub fn has_read_permission(&self) -> bool {
matches!(self.current_permissions, Some(crate::rpc::Permissions::Read) | Some(crate::rpc::Permissions::ReadWrite))
}
/// Check if current permissions allow write operations
pub fn has_write_permission(&self) -> bool {
matches!(self.current_permissions, Some(crate::rpc::Permissions::ReadWrite))
}
// ----- BLPOP waiter helpers -----

View File

@@ -28,6 +28,7 @@ async fn debug_hset_simple() {
encrypt: false,
encryption_key: None,
backend: herodb::options::BackendType::Redb,
admin_secret: "test-admin".to_string(),
};
let mut server = Server::new(option).await;
@@ -48,6 +49,12 @@ async fn debug_hset_simple() {
sleep(Duration::from_millis(200)).await;
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port)).await.unwrap();
// Acquire ReadWrite permissions on this connection
let resp = send_command(
&mut stream,
"*4\r\n$6\r\nSELECT\r\n$1\r\n0\r\n$3\r\nKEY\r\n$10\r\ntest-admin\r\n",
).await;
assert!(resp.contains("OK"), "Failed SELECT handshake: {}", resp);
// Test simple HSET
println!("Testing HSET...");

View File

@@ -19,6 +19,7 @@ async fn debug_hset_return_value() {
encrypt: false,
encryption_key: None,
backend: herodb::options::BackendType::Redb,
admin_secret: "test-admin".to_string(),
};
let mut server = Server::new(option).await;
@@ -41,11 +42,18 @@ async fn debug_hset_return_value() {
// Connect and test HSET
let mut stream = TcpStream::connect("127.0.0.1:16390").await.unwrap();
// Acquire ReadWrite permissions for this new connection
let handshake = "*4\r\n$6\r\nSELECT\r\n$1\r\n0\r\n$3\r\nKEY\r\n$10\r\ntest-admin\r\n";
stream.write_all(handshake.as_bytes()).await.unwrap();
let mut buffer = [0; 1024];
let n = stream.read(&mut buffer).await.unwrap();
let resp = String::from_utf8_lossy(&buffer[..n]);
assert!(resp.contains("OK"), "Failed SELECT handshake: {}", resp);
// Send HSET command
let cmd = "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n";
stream.write_all(cmd.as_bytes()).await.unwrap();
let mut buffer = [0; 1024];
let n = stream.read(&mut buffer).await.unwrap();
let response = String::from_utf8_lossy(&buffer[..n]);

View File

@@ -12,7 +12,15 @@ fn get_redis_connection(port: u16) -> Connection {
match client.get_connection() {
Ok(mut conn) => {
if redis::cmd("PING").query::<String>(&mut conn).is_ok() {
return conn;
// Acquire ReadWrite permissions on this connection
let sel: RedisResult<String> = redis::cmd("SELECT")
.arg(0)
.arg("KEY")
.arg("test-admin")
.query(&mut conn);
if sel.is_ok() {
return conn;
}
}
}
Err(e) => {
@@ -78,6 +86,8 @@ fn setup_server() -> (ServerProcessGuard, u16) {
"--port",
&port.to_string(),
"--debug",
"--admin-secret",
"test-admin",
])
.spawn()
.expect("Failed to start server process");

View File

@@ -23,18 +23,29 @@ async fn start_test_server(test_name: &str) -> (Server, u16) {
encrypt: false,
encryption_key: None,
backend: herodb::options::BackendType::Redb,
admin_secret: "test-admin".to_string(),
};
let server = Server::new(option).await;
(server, port)
}
// Helper function to connect to the test server
// Helper function to connect to the test server
async fn connect_to_server(port: u16) -> TcpStream {
let mut attempts = 0;
loop {
match TcpStream::connect(format!("127.0.0.1:{}", port)).await {
Ok(stream) => return stream,
Ok(mut stream) => {
// Obtain ReadWrite permissions for this connection by selecting DB 0 with admin key
let resp = send_command(
&mut stream,
"*4\r\n$6\r\nSELECT\r\n$1\r\n0\r\n$3\r\nKEY\r\n$10\r\ntest-admin\r\n",
).await;
if !resp.contains("OK") {
panic!("Failed to acquire write permissions via SELECT 0 KEY test-admin: {}", resp);
}
return stream;
}
Err(_) if attempts < 10 => {
attempts += 1;
sleep(Duration::from_millis(100)).await;

62
tests/rpc_tests.rs Normal file
View File

@@ -0,0 +1,62 @@
use std::net::SocketAddr;
use jsonrpsee::http_client::HttpClientBuilder;
use jsonrpsee::core::client::ClientT;
use serde_json::json;
use herodb::rpc::{RpcClient, BackendType, DatabaseConfig};
#[tokio::test]
async fn test_rpc_server_basic() {
// This test would require starting the RPC server in a separate thread
// For now, we'll just test that the types compile correctly
// Test serialization of types
let backend = BackendType::Redb;
let config = DatabaseConfig {
name: Some("test_db".to_string()),
storage_path: Some("/tmp/test".to_string()),
max_size: Some(1024 * 1024),
redis_version: Some("7.0".to_string()),
};
let backend_json = serde_json::to_string(&backend).unwrap();
let config_json = serde_json::to_string(&config).unwrap();
assert_eq!(backend_json, "\"Redb\"");
assert!(config_json.contains("test_db"));
}
#[tokio::test]
async fn test_database_config_serialization() {
let config = DatabaseConfig {
name: Some("my_db".to_string()),
storage_path: None,
max_size: Some(1000000),
redis_version: Some("7.0".to_string()),
};
let json = serde_json::to_value(&config).unwrap();
assert_eq!(json["name"], "my_db");
assert_eq!(json["max_size"], 1000000);
assert_eq!(json["redis_version"], "7.0");
}
#[tokio::test]
async fn test_backend_type_serialization() {
// Test that both Redb and Sled backends serialize correctly
let redb_backend = BackendType::Redb;
let sled_backend = BackendType::Sled;
let redb_json = serde_json::to_string(&redb_backend).unwrap();
let sled_json = serde_json::to_string(&sled_backend).unwrap();
assert_eq!(redb_json, "\"Redb\"");
assert_eq!(sled_json, "\"Sled\"");
// Test deserialization
let redb_deserialized: BackendType = serde_json::from_str(&redb_json).unwrap();
let sled_deserialized: BackendType = serde_json::from_str(&sled_json).unwrap();
assert!(matches!(redb_deserialized, BackendType::Redb));
assert!(matches!(sled_deserialized, BackendType::Sled));
}

View File

@@ -25,6 +25,7 @@ async fn start_test_server(test_name: &str) -> (Server, u16) {
encrypt: false,
encryption_key: None,
backend: herodb::options::BackendType::Redb,
admin_secret: "test-admin".to_string(),
};
let server = Server::new(option).await;
@@ -34,9 +35,16 @@ async fn start_test_server(test_name: &str) -> (Server, u16) {
// Helper function to send Redis command and get response
async fn send_redis_command(port: u16, command: &str) -> String {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port)).await.unwrap();
// Acquire ReadWrite permissions on this new connection
let handshake = "*4\r\n$6\r\nSELECT\r\n$1\r\n0\r\n$3\r\nKEY\r\n$10\r\ntest-admin\r\n";
stream.write_all(handshake.as_bytes()).await.unwrap();
let mut buffer = [0; 1024];
let _ = stream.read(&mut buffer).await.unwrap(); // Read and ignore the OK for handshake
// Now send the intended command
stream.write_all(command.as_bytes()).await.unwrap();
let mut buffer = [0; 1024];
let n = stream.read(&mut buffer).await.unwrap();
String::from_utf8_lossy(&buffer[..n]).to_string()
}
@@ -184,12 +192,19 @@ async fn test_transaction_operations() {
sleep(Duration::from_millis(100)).await;
// Use a single connection for the transaction
// Use a single connection for the transaction
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port)).await.unwrap();
// Acquire write permissions for this connection
let handshake = "*4\r\n$6\r\nSELECT\r\n$1\r\n0\r\n$3\r\nKEY\r\n$10\r\ntest-admin\r\n";
stream.write_all(handshake.as_bytes()).await.unwrap();
let mut buffer = [0; 1024];
let n = stream.read(&mut buffer).await.unwrap();
let resp = String::from_utf8_lossy(&buffer[..n]);
assert!(resp.contains("OK"));
// Test MULTI
stream.write_all("*1\r\n$5\r\nMULTI\r\n".as_bytes()).await.unwrap();
let mut buffer = [0; 1024];
let n = stream.read(&mut buffer).await.unwrap();
let response = String::from_utf8_lossy(&buffer[..n]);
assert!(response.contains("OK"));

View File

@@ -23,6 +23,7 @@ async fn start_test_server(test_name: &str) -> (Server, u16) {
encrypt: false,
encryption_key: None,
backend: herodb::options::BackendType::Redb,
admin_secret: "test-admin".to_string(),
};
let server = Server::new(option).await;
@@ -38,12 +39,22 @@ async fn send_command(stream: &mut TcpStream, command: &str) -> String {
String::from_utf8_lossy(&buffer[..n]).to_string()
}
// Helper function to connect to the test server
// Helper function to connect to the test server
async fn connect_to_server(port: u16) -> TcpStream {
let mut attempts = 0;
loop {
match TcpStream::connect(format!("127.0.0.1:{}", port)).await {
Ok(stream) => return stream,
Ok(mut stream) => {
// Acquire ReadWrite permissions for this connection
let resp = send_command(
&mut stream,
"*4\r\n$6\r\nSELECT\r\n$1\r\n0\r\n$3\r\nKEY\r\n$10\r\ntest-admin\r\n",
).await;
if !resp.contains("OK") {
panic!("Failed to acquire write permissions via SELECT 0 KEY test-admin: {}", resp);
}
return stream;
}
Err(_) if attempts < 10 => {
attempts += 1;
sleep(Duration::from_millis(100)).await;
@@ -98,13 +109,20 @@ async fn test_hset_clean_db() {
let mut stream = connect_to_server(port).await;
// Test HSET - should return 1 for new field
let response = send_command(&mut stream, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n").await;
// Ensure clean DB state (admin DB 0 may be shared due to global singleton)
let flush = send_command(&mut stream, "*1\r\n$7\r\nFLUSHDB\r\n").await;
assert!(flush.contains("OK"), "Failed to FLUSHDB: {}", flush);
// Test HSET - should return 1 for new field (use a unique key name to avoid collisions)
let key = "hash_clean";
let hset_cmd = format!("*4\r\n$4\r\nHSET\r\n${}\r\n{}\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n", key.len(), key);
let response = send_command(&mut stream, &hset_cmd).await;
println!("HSET response: {}", response);
assert!(response.contains("1"), "Expected HSET to return 1, got: {}", response);
// Test HGET
let response = send_command(&mut stream, "*3\r\n$4\r\nHGET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await;
let hget_cmd = format!("*3\r\n$4\r\nHGET\r\n${}\r\n{}\r\n$6\r\nfield1\r\n", key.len(), key);
let response = send_command(&mut stream, &hget_cmd).await;
println!("HGET response: {}", response);
assert!(response.contains("value1"));
}

View File

@@ -23,6 +23,7 @@ async fn start_test_server(test_name: &str) -> (Server, u16) {
encrypt: false,
encryption_key: None,
backend: herodb::options::BackendType::Redb,
admin_secret: "test-admin".to_string(),
};
let server = Server::new(option).await;
@@ -61,7 +62,17 @@ async fn connect(port: u16) -> TcpStream {
let mut attempts = 0;
loop {
match TcpStream::connect(format!("127.0.0.1:{}", port)).await {
Ok(s) => return s,
Ok(mut s) => {
// Acquire ReadWrite permissions for this connection using admin DB 0
let resp = send_cmd(&mut s, &["SELECT", "0", "KEY", "test-admin"]).await;
assert_contains(&resp, "OK", "SELECT 0 KEY test-admin handshake");
// Ensure clean slate per test on DB 0
let fl = send_cmd(&mut s, &["FLUSHDB"]).await;
assert_contains(&fl, "OK", "FLUSHDB after handshake");
return s;
}
Err(_) if attempts < 30 => {
attempts += 1;
sleep(Duration::from_millis(100)).await;
@@ -246,9 +257,9 @@ async fn test_01_connection_and_info() {
let getname = send_cmd(&mut s, &["CLIENT", "GETNAME"]).await;
assert_contains(&getname, "myapp", "CLIENT GETNAME");
// SELECT db
let sel = send_cmd(&mut s, &["SELECT", "0"]).await;
assert_contains(&sel, "OK", "SELECT 0");
// SELECT db (requires key on DB 0)
let sel = send_cmd(&mut s, &["SELECT", "0", "KEY", "test-admin"]).await;
assert_contains(&sel, "OK", "SELECT 0 with key");
// QUIT should close connection after sending OK
let quit = send_cmd(&mut s, &["QUIT"]).await;