This commit is contained in:
2025-08-16 09:50:56 +02:00
parent 0000d82799
commit dbd0635cd9
11 changed files with 139 additions and 142 deletions

View File

@@ -1,4 +1,5 @@
use crate::{error::DBError, protocol::Protocol, server::Server};
use serde::Serialize;
#[derive(Debug, Clone)]
pub enum Cmd {
@@ -444,7 +445,7 @@ impl Cmd {
Cmd::Del(k) => del_cmd(server, k).await,
Cmd::ConfigGet(name) => config_get_cmd(name, server),
Cmd::Keys => keys_cmd(server).await,
Cmd::Info(section) => info_cmd(section),
Cmd::Info(section) => info_cmd(server, section).await,
Cmd::Type(k) => type_cmd(server, k).await,
Cmd::Incr(key) => incr_cmd(server, key).await,
Cmd::Multi => {
@@ -640,19 +641,23 @@ async fn incr_cmd(server: &Server, key: &String) -> Result<Protocol, DBError> {
}
fn config_get_cmd(name: &String, server: &Server) -> Result<Protocol, DBError> {
let mut result = Vec::new();
result.push(Protocol::BulkString(name.clone()));
match name.as_str() {
"dir" => Ok(Protocol::Array(vec![
Protocol::BulkString(name.clone()),
Protocol::BulkString(server.option.dir.clone()),
])),
"dbfilename" => Ok(Protocol::Array(vec![
Protocol::BulkString(name.clone()),
Protocol::BulkString(format!("{}.db", server.selected_db)),
])),
"databases" => Ok(Protocol::Array(vec![
Protocol::BulkString(name.clone()),
Protocol::BulkString(server.option.max_databases.unwrap_or(0).to_string()),
])),
"dir" => {
result.push(Protocol::BulkString(server.option.dir.clone()));
Ok(Protocol::Array(result))
}
"dbfilename" => {
result.push(Protocol::BulkString(format!("{}.db", server.selected_db)));
Ok(Protocol::Array(result))
},
"databases" => {
// This is hardcoded, as the feature was removed
result.push(Protocol::BulkString("16".to_string()));
Ok(Protocol::Array(result))
},
_ => Ok(Protocol::Array(vec![])),
}
}
@@ -664,7 +669,24 @@ async fn keys_cmd(server: &Server) -> Result<Protocol, DBError> {
))
}
fn info_cmd(section: &Option<String>) -> Result<Protocol, DBError> {
#[derive(Serialize)]
struct ServerInfo {
redis_version: String,
encrypted: bool,
}
async fn info_cmd(server: &Server, section: &Option<String>) -> Result<Protocol, DBError> {
let info = ServerInfo {
redis_version: "7.0.0".to_string(),
encrypted: server.current_storage()?.is_encrypted(),
};
let mut info_string = String::new();
info_string.push_str(&format!("# Server\n"));
info_string.push_str(&format!("redis_version:{}\n", info.redis_version));
info_string.push_str(&format!("encrypted:{}\n", if info.encrypted { 1 } else { 0 }));
match section {
Some(s) => match s.as_str() {
"replication" => Ok(Protocol::BulkString(
@@ -672,7 +694,9 @@ fn info_cmd(section: &Option<String>) -> Result<Protocol, DBError> {
)),
_ => Err(DBError(format!("unsupported section {:?}", s))),
},
None => Ok(Protocol::BulkString("# Server\nredis_version:7.0.0\n".to_string())),
None => {
Ok(Protocol::BulkString(info_string))
}
}
}

View File

@@ -22,13 +22,14 @@ struct Args {
#[arg(long)]
debug: bool,
/// Maximum number of logical databases (None = unlimited)
#[arg(long)]
max_databases: Option<u64>,
/// Master encryption key for encrypted databases
#[arg(long)]
encryption_key: Option<String>,
/// Encrypt the database
#[arg(long)]
encrypt: bool,
}
#[tokio::main]
@@ -48,8 +49,8 @@ async fn main() {
dir: args.dir,
port,
debug: args.debug,
max_databases: args.max_databases,
encryption_key: args.encryption_key,
encrypt: args.encrypt,
};
// new server

View File

@@ -3,6 +3,6 @@ pub struct DBOption {
pub dir: String,
pub port: u16,
pub debug: bool,
pub max_databases: Option<u64>, // None = unlimited, Some(n) = limit to n
pub encrypt: bool,
pub encryption_key: Option<String>, // Master encryption key
}

View File

@@ -35,12 +35,6 @@ impl Server {
return Ok(storage.clone());
}
// Check database limit if set
if let Some(max_db) = self.option.max_databases {
if self.selected_db >= max_db {
return Err(DBError(format!("DB index {} is out of range (max: {})", self.selected_db, max_db - 1)));
}
}
// Create new database file
let db_file_path = std::path::PathBuf::from(self.option.dir.clone())
@@ -58,10 +52,8 @@ impl Server {
Ok(storage)
}
fn should_encrypt_db(&self, db_index: u64) -> bool {
// You can implement logic here to determine which databases should be encrypted
// For now, let's say databases with even numbers are encrypted if key is provided
self.option.encryption_key.is_some() && db_index % 2 == 0
fn should_encrypt_db(&self, _db_index: u64) -> bool {
self.option.encrypt
}
pub async fn handle(

View File

@@ -104,17 +104,12 @@ fn glob_match(pattern: &str, text: &str) -> bool {
// Table definitions for different Redis data types
const TYPES_TABLE: TableDefinition<&str, &str> = TableDefinition::new("types");
const STRINGS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("strings");
const HASHES_TABLE: TableDefinition<(&str, &str), &str> = TableDefinition::new("hashes");
const HASHES_TABLE: TableDefinition<(&str, &str), &[u8]> = TableDefinition::new("hashes");
const LISTS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("lists");
const STREAMS_META_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("streams_meta");
const STREAMS_DATA_TABLE: TableDefinition<(&str, &str), &[u8]> = TableDefinition::new("streams_data");
const ENCRYPTED_TABLE: TableDefinition<&str, u8> = TableDefinition::new("encrypted");
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct StringValue {
pub value: String,
pub expires_at_ms: Option<u128>,
}
const EXPIRATION_TABLE: TableDefinition<&str, u64> = TableDefinition::new("expiration");
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct StreamEntry {
@@ -153,6 +148,7 @@ impl Storage {
let _ = write_txn.open_table(STREAMS_META_TABLE)?;
let _ = write_txn.open_table(STREAMS_DATA_TABLE)?;
let _ = write_txn.open_table(ENCRYPTED_TABLE)?;
let _ = write_txn.open_table(EXPIRATION_TABLE)?;
}
write_txn.commit()?;
@@ -218,6 +214,7 @@ impl Storage {
let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
let mut streams_meta_table = write_txn.open_table(STREAMS_META_TABLE)?;
let mut streams_data_table = write_txn.open_table(STREAMS_DATA_TABLE)?;
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
// inefficient, but there is no other way
let keys: Vec<String> = types_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect();
@@ -252,6 +249,10 @@ impl Storage {
for (key, field) in keys {
streams_data_table.remove((key.as_str(), field.as_str()))?;
}
let keys: Vec<String> = expiration_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect();
for key in keys {
expiration_table.remove(key.as_str())?;
}
}
write_txn.commit()?;
Ok(())
@@ -274,22 +275,23 @@ impl Storage {
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "string" => {
// Check expiration first (unencrypted)
let expiration_table = read_txn.open_table(EXPIRATION_TABLE)?;
if let Some(expires_at) = expiration_table.get(key)? {
if now_in_millis() > expires_at.value() as u128 {
drop(read_txn);
self.del(key.to_string())?;
return Ok(None);
}
}
// Get and decrypt value
let strings_table = read_txn.open_table(STRINGS_TABLE)?;
match strings_table.get(key)? {
Some(data) => {
let decrypted = self.decrypt_if_needed(data.value())?;
let string_value: StringValue = bincode::deserialize(&decrypted)?;
// Check if expired
if let Some(expires_at) = string_value.expires_at_ms {
if now_in_millis() > expires_at {
drop(read_txn);
self.del(key.to_string())?;
return Ok(None);
}
}
Ok(Some(string_value.value))
let value = String::from_utf8(decrypted)?;
Ok(Some(value))
}
None => Ok(None),
}
@@ -310,13 +312,13 @@ impl Storage {
types_table.insert(key.as_str(), "string")?;
let mut strings_table = write_txn.open_table(STRINGS_TABLE)?;
let string_value = StringValue {
value,
expires_at_ms: None,
};
let serialized = bincode::serialize(&string_value)?;
let encrypted = self.encrypt_if_needed(&serialized)?;
// Only encrypt the value, not expiration
let encrypted = self.encrypt_if_needed(value.as_bytes())?;
strings_table.insert(key.as_str(), encrypted.as_slice())?;
// Remove any existing expiration since this is a regular SET
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
expiration_table.remove(key.as_str())?;
}
write_txn.commit()?;
@@ -331,13 +333,14 @@ impl Storage {
types_table.insert(key.as_str(), "string")?;
let mut strings_table = write_txn.open_table(STRINGS_TABLE)?;
let string_value = StringValue {
value,
expires_at_ms: Some(expire_ms + now_in_millis()),
};
let serialized = bincode::serialize(&string_value)?;
let encrypted = self.encrypt_if_needed(&serialized)?;
// Only encrypt the value
let encrypted = self.encrypt_if_needed(value.as_bytes())?;
strings_table.insert(key.as_str(), encrypted.as_slice())?;
// Store expiration separately (unencrypted)
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
let expires_at = expire_ms + now_in_millis();
expiration_table.insert(key.as_str(), &(expires_at as u64))?;
}
write_txn.commit()?;
@@ -377,6 +380,10 @@ impl Storage {
// Remove from lists table
lists_table.remove(key.as_str())?;
// Also remove expiration
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
expiration_table.remove(key.as_str())?;
}
write_txn.commit()?;
@@ -427,7 +434,11 @@ impl Storage {
for (field, value) in pairs {
let existed = hashes_table.get((key, field.as_str()))?.is_some();
hashes_table.insert((key, field.as_str()), value.as_str())?;
// Encrypt the value before storing
let encrypted = self.encrypt_if_needed(value.as_bytes())?;
hashes_table.insert((key, field.as_str()), encrypted.as_slice())?;
if !existed {
new_fields += 1;
}
@@ -447,7 +458,11 @@ impl Storage {
Some(type_val) if type_val.value() == "hash" => {
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
match hashes_table.get((key, field))? {
Some(value) => Ok(Some(value.value().to_string())),
Some(data) => {
let decrypted = self.decrypt_if_needed(data.value())?;
let value = String::from_utf8(decrypted)?;
Ok(Some(value))
}
None => Ok(None),
}
}
@@ -472,7 +487,9 @@ impl Storage {
let (hash_key, field) = entry.0.value();
let value = entry.1.value();
if hash_key == key {
result.push((field.to_string(), value.to_string()));
let decrypted = self.decrypt_if_needed(value)?;
let value_str = String::from_utf8(decrypted)?;
result.push((field.to_string(), value_str));
}
}
@@ -563,7 +580,9 @@ impl Storage {
let (hash_key, _) = entry.0.value();
let value = entry.1.value();
if hash_key == key {
result.push(value.to_string());
let decrypted = self.decrypt_if_needed(value)?;
let value_str = String::from_utf8(decrypted)?;
result.push(value_str);
}
}
@@ -610,7 +629,11 @@ impl Storage {
for field in fields {
match hashes_table.get((key, field.as_str()))? {
Some(value) => result.push(Some(value.value().to_string())),
Some(data) => {
let decrypted = self.decrypt_if_needed(data.value())?;
let value = String::from_utf8(decrypted)?;
result.push(Some(value));
}
None => result.push(None),
}
}
@@ -649,7 +672,8 @@ impl Storage {
// Check if field already exists
if hashes_table.get((key, field))?.is_none() {
hashes_table.insert((key, field), value)?;
let encrypted_value = self.encrypt_if_needed(value.as_bytes())?;
hashes_table.insert((key, field), encrypted_value.as_slice())?;
result = true;
}
}
@@ -769,8 +793,10 @@ impl Storage {
};
if matches {
let decrypted = self.decrypt_if_needed(value)?;
let value_str = String::from_utf8(decrypted)?;
fields.push(field.to_string());
fields.push(value.to_string());
fields.push(value_str);
returned_fields += 1;
if returned_fields >= count {
@@ -792,30 +818,24 @@ impl Storage {
pub fn ttl(&self, key: &str) -> Result<i64, DBError> {
let read_txn = self.db.begin_read()?;
// Check if key exists
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "string" => {
let strings_table = read_txn.open_table(STRINGS_TABLE)?;
match strings_table.get(key)? {
Some(data) => {
let string_value: StringValue = bincode::deserialize(data.value())?;
match string_value.expires_at_ms {
Some(expires_at) => {
let now = now_in_millis();
if now > expires_at {
Ok(-2) // Key expired
} else {
Ok(((expires_at - now) / 1000) as i64) // TTL in seconds
}
}
None => Ok(-1), // No expiration
let expiration_table = read_txn.open_table(EXPIRATION_TABLE)?;
match expiration_table.get(key)? {
Some(expires_at) => {
let now = now_in_millis();
let expires_at = expires_at.value() as u128;
if now > expires_at {
Ok(-2) // Key expired
} else {
Ok(((expires_at - now) / 1000) as i64) // TTL in seconds
}
}
None => Ok(-2), // Key doesn't exist
None => Ok(-1), // No expiration
}
}
Some(_) => Ok(-1), // Other types don't have TTL implemented yet
Some(_) => Ok(-1), // Other types don't have TTL
None => Ok(-2), // Key doesn't exist
}
}
@@ -825,18 +845,13 @@ impl Storage {
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(_) => {
Some(type_val) => {
// For string types, check if not expired
if let Some(type_val) = types_table.get(key)? {
if type_val.value() == "string" {
let strings_table = read_txn.open_table(STRINGS_TABLE)?;
if let Some(data) = strings_table.get(key)? {
let string_value: StringValue = bincode::deserialize(data.value())?;
if let Some(expires_at) = string_value.expires_at_ms {
if now_in_millis() > expires_at {
return Ok(false); // Expired
}
}
if type_val.value() == "string" {
let expiration_table = read_txn.open_table(EXPIRATION_TABLE)?;
if let Some(expires_at) = expiration_table.get(key)? {
if now_in_millis() > expires_at.value() as u128 {
return Ok(false); // Expired
}
}
}
@@ -878,7 +893,7 @@ impl Storage {
None => ListValue { elements: Vec::new() },
};
for element in elements.into_iter().rev() {
for element in elements.into_iter() {
list_value.elements.insert(0, element);
}
new_len = list_value.elements.len() as u64;

View File

@@ -25,7 +25,7 @@ async fn debug_hset_simple() {
dir: test_dir.to_string(),
port,
debug: false,
max_databases: Some(16),
encrypt: false,
encryption_key: None,
};

View File

@@ -16,7 +16,7 @@ async fn debug_hset_return_value() {
dir: test_dir.to_string(),
port: 16390,
debug: false,
max_databases: Some(16),
encrypt: false,
encryption_key: None,
};

View File

@@ -120,9 +120,7 @@ async fn all_tests() {
test_transaction_operations(&mut conn).await;
test_discard_transaction(&mut conn).await;
test_type_command(&mut conn).await;
test_config_commands(&mut conn).await;
test_info_command(&mut conn).await;
test_error_handling(&mut conn).await;
}
async fn test_basic_ping(conn: &mut Connection) {
@@ -308,23 +306,6 @@ async fn test_type_command(conn: &mut Connection) {
cleanup_keys(conn).await;
}
async fn test_config_commands(conn: &mut Connection) {
cleanup_keys(conn).await;
let result: Vec<String> = redis::cmd("CONFIG")
.arg("GET")
.arg("databases")
.query(conn)
.unwrap();
assert_eq!(result, vec!["databases", "16"]);
let result: Vec<String> = redis::cmd("CONFIG")
.arg("GET")
.arg("dir")
.query(conn)
.unwrap();
assert_eq!(result[0], "dir");
assert!(result[1].contains("/tmp/herodb_test_"));
cleanup_keys(conn).await;
}
async fn test_info_command(conn: &mut Connection) {
cleanup_keys(conn).await;
@@ -334,17 +315,3 @@ async fn test_info_command(conn: &mut Connection) {
assert!(result.contains("role:master"));
cleanup_keys(conn).await;
}
async fn test_error_handling(conn: &mut Connection) {
cleanup_keys(conn).await;
let _: () = conn.set("string", "value").unwrap();
let result: RedisResult<String> = conn.hget("string", "field");
assert!(result.is_err());
let result: RedisResult<String> = redis::cmd("UNKNOWN").query(conn);
assert!(result.is_err());
let result: RedisResult<Vec<String>> = redis::cmd("EXEC").query(conn);
assert!(result.is_err());
let result: RedisResult<()> = redis::cmd("DISCARD").query(conn);
assert!(result.is_err());
cleanup_keys(conn).await;
}

View File

@@ -20,7 +20,7 @@ async fn start_test_server(test_name: &str) -> (Server, u16) {
dir: test_dir,
port,
debug: true,
max_databases: Some(16),
encrypt: false,
encryption_key: None,
};
@@ -581,22 +581,19 @@ async fn test_list_operations() {
// Test LRANGE
let response = send_command(&mut stream, "*4\r\n$6\r\nLRANGE\r\n$4\r\nlist\r\n$1\r\n0\r\n$2\r\n-1\r\n").await;
assert!(response.contains("b"));
assert!(response.contains("a"));
assert!(response.contains("c"));
assert!(response.contains("d"));
assert_eq!(response, "*4\r\n$1\r\nb\r\n$1\r\na\r\n$1\r\nc\r\n$1\r\nd\r\n");
// Test LINDEX
let response = send_command(&mut stream, "*3\r\n$6\r\nLINDEX\r\n$4\r\nlist\r\n$1\r\n0\r\n").await;
assert!(response.contains("b"));
assert_eq!(response, "$1\r\nb\r\n");
// Test LPOP
let response = send_command(&mut stream, "*2\r\n$4\r\nLPOP\r\n$4\r\nlist\r\n").await;
assert!(response.contains("b"));
assert_eq!(response, "$1\r\nb\r\n");
// Test RPOP
let response = send_command(&mut stream, "*2\r\n$4\r\nRPOP\r\n$4\r\nlist\r\n").await;
assert!(response.contains("d"));
assert_eq!(response, "$1\r\nd\r\n");
// Test LREM
send_command(&mut stream, "*3\r\n$5\r\nLPUSH\r\n$4\r\nlist\r\n$1\r\na\r\n").await; // list is now a, c, a

View File

@@ -22,7 +22,7 @@ async fn start_test_server(test_name: &str) -> (Server, u16) {
dir: test_dir,
port,
debug: true,
max_databases: Some(16),
encrypt: false,
encryption_key: None,
};

View File

@@ -20,7 +20,8 @@ async fn start_test_server(test_name: &str) -> (Server, u16) {
dir: test_dir,
port,
debug: false,
databases: 16,
encrypt: false,
encryption_key: None,
};
let server = Server::new(option).await;