WIP: adding access management control to db instances

This commit is contained in:
Maxime Van Hees
2025-09-12 17:11:50 +02:00
parent 8798bc202e
commit bdf363016a
3 changed files with 373 additions and 8 deletions

View File

@@ -6,7 +6,7 @@ use futures::future::select_all;
pub enum Cmd {
Ping,
Echo(String),
Select(u64), // Changed from u16 to u64
Select(u64, Option<String>), // db_index, optional_key
Get(String),
Set(String, String),
SetPx(String, String, u128),
@@ -98,11 +98,18 @@ impl Cmd {
Ok((
match cmd[0].to_lowercase().as_str() {
"select" => {
if cmd.len() != 2 {
if cmd.len() < 2 || cmd.len() > 4 {
return Err(DBError("wrong number of arguments for SELECT".to_string()));
}
let idx = cmd[1].parse::<u64>().map_err(|_| DBError("ERR DB index is not an integer".to_string()))?;
Cmd::Select(idx)
let key = if cmd.len() == 4 && cmd[2].to_lowercase() == "key" {
Some(cmd[3].clone())
} else if cmd.len() == 2 {
None
} else {
return Err(DBError("ERR syntax error".to_string()));
};
Cmd::Select(idx, key)
}
"echo" => Cmd::Echo(cmd[1].clone()),
"ping" => Cmd::Ping,
@@ -642,7 +649,7 @@ impl Cmd {
}
match self {
Cmd::Select(db) => select_cmd(server, db).await,
Cmd::Select(db, key) => select_cmd(server, db, key).await,
Cmd::Ping => Ok(Protocol::SimpleString("PONG".to_string())),
Cmd::Echo(s) => Ok(Protocol::BulkString(s)),
Cmd::Get(k) => get_cmd(server, &k).await,
@@ -736,7 +743,14 @@ impl Cmd {
pub fn to_protocol(self) -> Protocol {
match self {
Cmd::Select(db) => Protocol::Array(vec![Protocol::BulkString("select".to_string()), Protocol::BulkString(db.to_string())]),
Cmd::Select(db, key) => {
let mut arr = vec![Protocol::BulkString("select".to_string()), Protocol::BulkString(db.to_string())];
if let Some(k) = key {
arr.push(Protocol::BulkString("key".to_string()));
arr.push(Protocol::BulkString(k));
}
Protocol::Array(arr)
}
Cmd::Ping => Protocol::Array(vec![Protocol::BulkString("ping".to_string())]),
Cmd::Echo(s) => Protocol::Array(vec![Protocol::BulkString("echo".to_string()), Protocol::BulkString(s)]),
Cmd::Get(k) => Protocol::Array(vec![Protocol::BulkString("get".to_string()), Protocol::BulkString(k)]),
@@ -753,9 +767,44 @@ async fn flushdb_cmd(server: &mut Server) -> Result<Protocol, DBError> {
}
}
async fn select_cmd(server: &mut Server, db: u64) -> Result<Protocol, DBError> {
// Test if we can access the database (this will create it if needed)
async fn select_cmd(server: &mut Server, db: u64, key: Option<String>) -> Result<Protocol, DBError> {
// Load database metadata
let meta = match crate::rpc::RpcServerImpl::load_meta_static(&server.option.dir, db).await {
Ok(m) => m,
Err(_) => {
// If meta doesn't exist, create default
let default_meta = crate::rpc::DatabaseMeta {
public: true,
keys: std::collections::HashMap::new(),
};
if let Err(_) = crate::rpc::RpcServerImpl::save_meta_static(&server.option.dir, db, &default_meta).await {
return Ok(Protocol::err("ERR failed to initialize database metadata"));
}
default_meta
}
};
// Check access permissions
let permissions = if meta.public {
// Public database - full access
Some(crate::rpc::Permissions::ReadWrite)
} else if let Some(key_str) = key {
// Private database - check key
let hash = crate::rpc::hash_key(&key_str);
if let Some(access_key) = meta.keys.get(&hash) {
Some(access_key.permissions.clone())
} else {
return Ok(Protocol::err("ERR invalid access key"));
}
} else {
return Ok(Protocol::err("ERR access key required for private database"));
};
// Set selected database and permissions
server.selected_db = db;
server.current_permissions = permissions;
// Test if we can access the database (this will create it if needed)
match server.current_storage() {
Ok(_) => Ok(Protocol::SimpleString("OK".to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
@@ -1003,6 +1052,9 @@ async fn brpop_cmd(server: &Server, keys: &[String], timeout_secs: f64) -> Resul
}
async fn lpush_cmd(server: &Server, key: &str, elements: &[String]) -> Result<Protocol, DBError> {
if !server.has_write_permission() {
return Ok(Protocol::err("ERR write permission denied"));
}
match server.current_storage()?.lpush(key, elements.to_vec()) {
Ok(len) => {
// Attempt to deliver to any blocked BLPOP waiters
@@ -1134,6 +1186,9 @@ async fn type_cmd(server: &Server, k: &String) -> Result<Protocol, DBError> {
}
async fn del_cmd(server: &Server, k: &str) -> Result<Protocol, DBError> {
if !server.has_write_permission() {
return Ok(Protocol::err("ERR write permission denied"));
}
server.current_storage()?.del(k.to_string())?;
Ok(Protocol::SimpleString("1".to_string()))
}
@@ -1159,6 +1214,9 @@ async fn set_px_cmd(
}
async fn set_cmd(server: &Server, k: &str, v: &str) -> Result<Protocol, DBError> {
if !server.has_write_permission() {
return Ok(Protocol::err("ERR write permission denied"));
}
server.current_storage()?.set(k.to_string(), v.to_string())?;
Ok(Protocol::SimpleString("OK".to_string()))
}
@@ -1273,6 +1331,9 @@ async fn get_cmd(server: &Server, k: &str) -> Result<Protocol, DBError> {
// Hash command implementations
async fn hset_cmd(server: &Server, key: &str, pairs: &[(String, String)]) -> Result<Protocol, DBError> {
if !server.has_write_permission() {
return Ok(Protocol::err("ERR write permission denied"));
}
let new_fields = server.current_storage()?.hset(key, pairs.to_vec())?;
Ok(Protocol::SimpleString(new_fields.to_string()))
}