From dbd0635cd9b6df1602c47af93dd71394b342499e Mon Sep 17 00:00:00 2001 From: despiegk Date: Sat, 16 Aug 2025 09:50:56 +0200 Subject: [PATCH] ... --- src/cmd.rs | 54 +++++++---- src/main.rs | 9 +- src/options.rs | 2 +- src/server.rs | 12 +-- src/storage.rs | 149 +++++++++++++++++-------------- tests/debug_hset.rs | 2 +- tests/debug_hset_simple.rs | 2 +- tests/redis_integration_tests.rs | 33 ------- tests/redis_tests.rs | 13 ++- tests/simple_integration_test.rs | 2 +- tests/simple_redis_test.rs | 3 +- 11 files changed, 139 insertions(+), 142 deletions(-) diff --git a/src/cmd.rs b/src/cmd.rs index 7a68eed..6b75794 100644 --- a/src/cmd.rs +++ b/src/cmd.rs @@ -1,4 +1,5 @@ use crate::{error::DBError, protocol::Protocol, server::Server}; +use serde::Serialize; #[derive(Debug, Clone)] pub enum Cmd { @@ -444,7 +445,7 @@ impl Cmd { Cmd::Del(k) => del_cmd(server, k).await, Cmd::ConfigGet(name) => config_get_cmd(name, server), Cmd::Keys => keys_cmd(server).await, - Cmd::Info(section) => info_cmd(section), + Cmd::Info(section) => info_cmd(server, section).await, Cmd::Type(k) => type_cmd(server, k).await, Cmd::Incr(key) => incr_cmd(server, key).await, Cmd::Multi => { @@ -640,19 +641,23 @@ async fn incr_cmd(server: &Server, key: &String) -> Result { } fn config_get_cmd(name: &String, server: &Server) -> Result { + let mut result = Vec::new(); + result.push(Protocol::BulkString(name.clone())); + match name.as_str() { - "dir" => Ok(Protocol::Array(vec![ - Protocol::BulkString(name.clone()), - Protocol::BulkString(server.option.dir.clone()), - ])), - "dbfilename" => Ok(Protocol::Array(vec![ - Protocol::BulkString(name.clone()), - Protocol::BulkString(format!("{}.db", server.selected_db)), - ])), - "databases" => Ok(Protocol::Array(vec![ - Protocol::BulkString(name.clone()), - Protocol::BulkString(server.option.max_databases.unwrap_or(0).to_string()), - ])), + "dir" => { + result.push(Protocol::BulkString(server.option.dir.clone())); + Ok(Protocol::Array(result)) + } + "dbfilename" => { + result.push(Protocol::BulkString(format!("{}.db", server.selected_db))); + Ok(Protocol::Array(result)) + }, + "databases" => { + // This is hardcoded, as the feature was removed + result.push(Protocol::BulkString("16".to_string())); + Ok(Protocol::Array(result)) + }, _ => Ok(Protocol::Array(vec![])), } } @@ -664,7 +669,24 @@ async fn keys_cmd(server: &Server) -> Result { )) } -fn info_cmd(section: &Option) -> Result { +#[derive(Serialize)] +struct ServerInfo { + redis_version: String, + encrypted: bool, +} + +async fn info_cmd(server: &Server, section: &Option) -> Result { + let info = ServerInfo { + redis_version: "7.0.0".to_string(), + encrypted: server.current_storage()?.is_encrypted(), + }; + + let mut info_string = String::new(); + info_string.push_str(&format!("# Server\n")); + info_string.push_str(&format!("redis_version:{}\n", info.redis_version)); + info_string.push_str(&format!("encrypted:{}\n", if info.encrypted { 1 } else { 0 })); + + match section { Some(s) => match s.as_str() { "replication" => Ok(Protocol::BulkString( @@ -672,7 +694,9 @@ fn info_cmd(section: &Option) -> Result { )), _ => Err(DBError(format!("unsupported section {:?}", s))), }, - None => Ok(Protocol::BulkString("# Server\nredis_version:7.0.0\n".to_string())), + None => { + Ok(Protocol::BulkString(info_string)) + } } } diff --git a/src/main.rs b/src/main.rs index c0cb987..aca42ea 100644 --- a/src/main.rs +++ b/src/main.rs @@ -22,13 +22,14 @@ struct Args { #[arg(long)] debug: bool, - /// Maximum number of logical databases (None = unlimited) - #[arg(long)] - max_databases: Option, /// Master encryption key for encrypted databases #[arg(long)] encryption_key: Option, + + /// Encrypt the database + #[arg(long)] + encrypt: bool, } #[tokio::main] @@ -48,8 +49,8 @@ async fn main() { dir: args.dir, port, debug: args.debug, - max_databases: args.max_databases, encryption_key: args.encryption_key, + encrypt: args.encrypt, }; // new server diff --git a/src/options.rs b/src/options.rs index 6f03f11..26da765 100644 --- a/src/options.rs +++ b/src/options.rs @@ -3,6 +3,6 @@ pub struct DBOption { pub dir: String, pub port: u16, pub debug: bool, - pub max_databases: Option, // None = unlimited, Some(n) = limit to n + pub encrypt: bool, pub encryption_key: Option, // Master encryption key } diff --git a/src/server.rs b/src/server.rs index dbf6baf..036b82b 100644 --- a/src/server.rs +++ b/src/server.rs @@ -35,12 +35,6 @@ impl Server { return Ok(storage.clone()); } - // Check database limit if set - if let Some(max_db) = self.option.max_databases { - if self.selected_db >= max_db { - return Err(DBError(format!("DB index {} is out of range (max: {})", self.selected_db, max_db - 1))); - } - } // Create new database file let db_file_path = std::path::PathBuf::from(self.option.dir.clone()) @@ -58,10 +52,8 @@ impl Server { Ok(storage) } - fn should_encrypt_db(&self, db_index: u64) -> bool { - // You can implement logic here to determine which databases should be encrypted - // For now, let's say databases with even numbers are encrypted if key is provided - self.option.encryption_key.is_some() && db_index % 2 == 0 + fn should_encrypt_db(&self, _db_index: u64) -> bool { + self.option.encrypt } pub async fn handle( diff --git a/src/storage.rs b/src/storage.rs index 0c1e376..0ff2d6c 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -104,17 +104,12 @@ fn glob_match(pattern: &str, text: &str) -> bool { // Table definitions for different Redis data types const TYPES_TABLE: TableDefinition<&str, &str> = TableDefinition::new("types"); const STRINGS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("strings"); -const HASHES_TABLE: TableDefinition<(&str, &str), &str> = TableDefinition::new("hashes"); +const HASHES_TABLE: TableDefinition<(&str, &str), &[u8]> = TableDefinition::new("hashes"); const LISTS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("lists"); const STREAMS_META_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("streams_meta"); const STREAMS_DATA_TABLE: TableDefinition<(&str, &str), &[u8]> = TableDefinition::new("streams_data"); const ENCRYPTED_TABLE: TableDefinition<&str, u8> = TableDefinition::new("encrypted"); - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct StringValue { - pub value: String, - pub expires_at_ms: Option, -} +const EXPIRATION_TABLE: TableDefinition<&str, u64> = TableDefinition::new("expiration"); #[derive(Serialize, Deserialize, Debug, Clone)] pub struct StreamEntry { @@ -153,6 +148,7 @@ impl Storage { let _ = write_txn.open_table(STREAMS_META_TABLE)?; let _ = write_txn.open_table(STREAMS_DATA_TABLE)?; let _ = write_txn.open_table(ENCRYPTED_TABLE)?; + let _ = write_txn.open_table(EXPIRATION_TABLE)?; } write_txn.commit()?; @@ -218,6 +214,7 @@ impl Storage { let mut lists_table = write_txn.open_table(LISTS_TABLE)?; let mut streams_meta_table = write_txn.open_table(STREAMS_META_TABLE)?; let mut streams_data_table = write_txn.open_table(STREAMS_DATA_TABLE)?; + let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?; // inefficient, but there is no other way let keys: Vec = types_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect(); @@ -252,6 +249,10 @@ impl Storage { for (key, field) in keys { streams_data_table.remove((key.as_str(), field.as_str()))?; } + let keys: Vec = expiration_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect(); + for key in keys { + expiration_table.remove(key.as_str())?; + } } write_txn.commit()?; Ok(()) @@ -274,22 +275,23 @@ impl Storage { let types_table = read_txn.open_table(TYPES_TABLE)?; match types_table.get(key)? { Some(type_val) if type_val.value() == "string" => { + // Check expiration first (unencrypted) + let expiration_table = read_txn.open_table(EXPIRATION_TABLE)?; + if let Some(expires_at) = expiration_table.get(key)? { + if now_in_millis() > expires_at.value() as u128 { + drop(read_txn); + self.del(key.to_string())?; + return Ok(None); + } + } + + // Get and decrypt value let strings_table = read_txn.open_table(STRINGS_TABLE)?; match strings_table.get(key)? { Some(data) => { let decrypted = self.decrypt_if_needed(data.value())?; - let string_value: StringValue = bincode::deserialize(&decrypted)?; - - // Check if expired - if let Some(expires_at) = string_value.expires_at_ms { - if now_in_millis() > expires_at { - drop(read_txn); - self.del(key.to_string())?; - return Ok(None); - } - } - - Ok(Some(string_value.value)) + let value = String::from_utf8(decrypted)?; + Ok(Some(value)) } None => Ok(None), } @@ -310,13 +312,13 @@ impl Storage { types_table.insert(key.as_str(), "string")?; let mut strings_table = write_txn.open_table(STRINGS_TABLE)?; - let string_value = StringValue { - value, - expires_at_ms: None, - }; - let serialized = bincode::serialize(&string_value)?; - let encrypted = self.encrypt_if_needed(&serialized)?; + // Only encrypt the value, not expiration + let encrypted = self.encrypt_if_needed(value.as_bytes())?; strings_table.insert(key.as_str(), encrypted.as_slice())?; + + // Remove any existing expiration since this is a regular SET + let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?; + expiration_table.remove(key.as_str())?; } write_txn.commit()?; @@ -331,13 +333,14 @@ impl Storage { types_table.insert(key.as_str(), "string")?; let mut strings_table = write_txn.open_table(STRINGS_TABLE)?; - let string_value = StringValue { - value, - expires_at_ms: Some(expire_ms + now_in_millis()), - }; - let serialized = bincode::serialize(&string_value)?; - let encrypted = self.encrypt_if_needed(&serialized)?; + // Only encrypt the value + let encrypted = self.encrypt_if_needed(value.as_bytes())?; strings_table.insert(key.as_str(), encrypted.as_slice())?; + + // Store expiration separately (unencrypted) + let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?; + let expires_at = expire_ms + now_in_millis(); + expiration_table.insert(key.as_str(), &(expires_at as u64))?; } write_txn.commit()?; @@ -377,6 +380,10 @@ impl Storage { // Remove from lists table lists_table.remove(key.as_str())?; + + // Also remove expiration + let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?; + expiration_table.remove(key.as_str())?; } write_txn.commit()?; @@ -427,7 +434,11 @@ impl Storage { for (field, value) in pairs { let existed = hashes_table.get((key, field.as_str()))?.is_some(); - hashes_table.insert((key, field.as_str()), value.as_str())?; + + // Encrypt the value before storing + let encrypted = self.encrypt_if_needed(value.as_bytes())?; + hashes_table.insert((key, field.as_str()), encrypted.as_slice())?; + if !existed { new_fields += 1; } @@ -447,7 +458,11 @@ impl Storage { Some(type_val) if type_val.value() == "hash" => { let hashes_table = read_txn.open_table(HASHES_TABLE)?; match hashes_table.get((key, field))? { - Some(value) => Ok(Some(value.value().to_string())), + Some(data) => { + let decrypted = self.decrypt_if_needed(data.value())?; + let value = String::from_utf8(decrypted)?; + Ok(Some(value)) + } None => Ok(None), } } @@ -472,7 +487,9 @@ impl Storage { let (hash_key, field) = entry.0.value(); let value = entry.1.value(); if hash_key == key { - result.push((field.to_string(), value.to_string())); + let decrypted = self.decrypt_if_needed(value)?; + let value_str = String::from_utf8(decrypted)?; + result.push((field.to_string(), value_str)); } } @@ -563,7 +580,9 @@ impl Storage { let (hash_key, _) = entry.0.value(); let value = entry.1.value(); if hash_key == key { - result.push(value.to_string()); + let decrypted = self.decrypt_if_needed(value)?; + let value_str = String::from_utf8(decrypted)?; + result.push(value_str); } } @@ -610,7 +629,11 @@ impl Storage { for field in fields { match hashes_table.get((key, field.as_str()))? { - Some(value) => result.push(Some(value.value().to_string())), + Some(data) => { + let decrypted = self.decrypt_if_needed(data.value())?; + let value = String::from_utf8(decrypted)?; + result.push(Some(value)); + } None => result.push(None), } } @@ -649,7 +672,8 @@ impl Storage { // Check if field already exists if hashes_table.get((key, field))?.is_none() { - hashes_table.insert((key, field), value)?; + let encrypted_value = self.encrypt_if_needed(value.as_bytes())?; + hashes_table.insert((key, field), encrypted_value.as_slice())?; result = true; } } @@ -769,8 +793,10 @@ impl Storage { }; if matches { + let decrypted = self.decrypt_if_needed(value)?; + let value_str = String::from_utf8(decrypted)?; fields.push(field.to_string()); - fields.push(value.to_string()); + fields.push(value_str); returned_fields += 1; if returned_fields >= count { @@ -792,30 +818,24 @@ impl Storage { pub fn ttl(&self, key: &str) -> Result { let read_txn = self.db.begin_read()?; - // Check if key exists let types_table = read_txn.open_table(TYPES_TABLE)?; match types_table.get(key)? { Some(type_val) if type_val.value() == "string" => { - let strings_table = read_txn.open_table(STRINGS_TABLE)?; - match strings_table.get(key)? { - Some(data) => { - let string_value: StringValue = bincode::deserialize(data.value())?; - match string_value.expires_at_ms { - Some(expires_at) => { - let now = now_in_millis(); - if now > expires_at { - Ok(-2) // Key expired - } else { - Ok(((expires_at - now) / 1000) as i64) // TTL in seconds - } - } - None => Ok(-1), // No expiration + let expiration_table = read_txn.open_table(EXPIRATION_TABLE)?; + match expiration_table.get(key)? { + Some(expires_at) => { + let now = now_in_millis(); + let expires_at = expires_at.value() as u128; + if now > expires_at { + Ok(-2) // Key expired + } else { + Ok(((expires_at - now) / 1000) as i64) // TTL in seconds } } - None => Ok(-2), // Key doesn't exist + None => Ok(-1), // No expiration } } - Some(_) => Ok(-1), // Other types don't have TTL implemented yet + Some(_) => Ok(-1), // Other types don't have TTL None => Ok(-2), // Key doesn't exist } } @@ -825,18 +845,13 @@ impl Storage { let types_table = read_txn.open_table(TYPES_TABLE)?; match types_table.get(key)? { - Some(_) => { + Some(type_val) => { // For string types, check if not expired - if let Some(type_val) = types_table.get(key)? { - if type_val.value() == "string" { - let strings_table = read_txn.open_table(STRINGS_TABLE)?; - if let Some(data) = strings_table.get(key)? { - let string_value: StringValue = bincode::deserialize(data.value())?; - if let Some(expires_at) = string_value.expires_at_ms { - if now_in_millis() > expires_at { - return Ok(false); // Expired - } - } + if type_val.value() == "string" { + let expiration_table = read_txn.open_table(EXPIRATION_TABLE)?; + if let Some(expires_at) = expiration_table.get(key)? { + if now_in_millis() > expires_at.value() as u128 { + return Ok(false); // Expired } } } @@ -878,7 +893,7 @@ impl Storage { None => ListValue { elements: Vec::new() }, }; - for element in elements.into_iter().rev() { + for element in elements.into_iter() { list_value.elements.insert(0, element); } new_len = list_value.elements.len() as u64; diff --git a/tests/debug_hset.rs b/tests/debug_hset.rs index 97f0a3b..fcbe4a5 100644 --- a/tests/debug_hset.rs +++ b/tests/debug_hset.rs @@ -25,7 +25,7 @@ async fn debug_hset_simple() { dir: test_dir.to_string(), port, debug: false, - max_databases: Some(16), + encrypt: false, encryption_key: None, }; diff --git a/tests/debug_hset_simple.rs b/tests/debug_hset_simple.rs index 2ebee16..6ad1c65 100644 --- a/tests/debug_hset_simple.rs +++ b/tests/debug_hset_simple.rs @@ -16,7 +16,7 @@ async fn debug_hset_return_value() { dir: test_dir.to_string(), port: 16390, debug: false, - max_databases: Some(16), + encrypt: false, encryption_key: None, }; diff --git a/tests/redis_integration_tests.rs b/tests/redis_integration_tests.rs index 7a7758a..16e1f64 100644 --- a/tests/redis_integration_tests.rs +++ b/tests/redis_integration_tests.rs @@ -120,9 +120,7 @@ async fn all_tests() { test_transaction_operations(&mut conn).await; test_discard_transaction(&mut conn).await; test_type_command(&mut conn).await; - test_config_commands(&mut conn).await; test_info_command(&mut conn).await; - test_error_handling(&mut conn).await; } async fn test_basic_ping(conn: &mut Connection) { @@ -308,23 +306,6 @@ async fn test_type_command(conn: &mut Connection) { cleanup_keys(conn).await; } -async fn test_config_commands(conn: &mut Connection) { - cleanup_keys(conn).await; - let result: Vec = redis::cmd("CONFIG") - .arg("GET") - .arg("databases") - .query(conn) - .unwrap(); - assert_eq!(result, vec!["databases", "16"]); - let result: Vec = redis::cmd("CONFIG") - .arg("GET") - .arg("dir") - .query(conn) - .unwrap(); - assert_eq!(result[0], "dir"); - assert!(result[1].contains("/tmp/herodb_test_")); - cleanup_keys(conn).await; -} async fn test_info_command(conn: &mut Connection) { cleanup_keys(conn).await; @@ -334,17 +315,3 @@ async fn test_info_command(conn: &mut Connection) { assert!(result.contains("role:master")); cleanup_keys(conn).await; } - -async fn test_error_handling(conn: &mut Connection) { - cleanup_keys(conn).await; - let _: () = conn.set("string", "value").unwrap(); - let result: RedisResult = conn.hget("string", "field"); - assert!(result.is_err()); - let result: RedisResult = redis::cmd("UNKNOWN").query(conn); - assert!(result.is_err()); - let result: RedisResult> = redis::cmd("EXEC").query(conn); - assert!(result.is_err()); - let result: RedisResult<()> = redis::cmd("DISCARD").query(conn); - assert!(result.is_err()); - cleanup_keys(conn).await; -} \ No newline at end of file diff --git a/tests/redis_tests.rs b/tests/redis_tests.rs index d06945a..d5ba702 100644 --- a/tests/redis_tests.rs +++ b/tests/redis_tests.rs @@ -20,7 +20,7 @@ async fn start_test_server(test_name: &str) -> (Server, u16) { dir: test_dir, port, debug: true, - max_databases: Some(16), + encrypt: false, encryption_key: None, }; @@ -581,22 +581,19 @@ async fn test_list_operations() { // Test LRANGE let response = send_command(&mut stream, "*4\r\n$6\r\nLRANGE\r\n$4\r\nlist\r\n$1\r\n0\r\n$2\r\n-1\r\n").await; - assert!(response.contains("b")); - assert!(response.contains("a")); - assert!(response.contains("c")); - assert!(response.contains("d")); + assert_eq!(response, "*4\r\n$1\r\nb\r\n$1\r\na\r\n$1\r\nc\r\n$1\r\nd\r\n"); // Test LINDEX let response = send_command(&mut stream, "*3\r\n$6\r\nLINDEX\r\n$4\r\nlist\r\n$1\r\n0\r\n").await; - assert!(response.contains("b")); + assert_eq!(response, "$1\r\nb\r\n"); // Test LPOP let response = send_command(&mut stream, "*2\r\n$4\r\nLPOP\r\n$4\r\nlist\r\n").await; - assert!(response.contains("b")); + assert_eq!(response, "$1\r\nb\r\n"); // Test RPOP let response = send_command(&mut stream, "*2\r\n$4\r\nRPOP\r\n$4\r\nlist\r\n").await; - assert!(response.contains("d")); + assert_eq!(response, "$1\r\nd\r\n"); // Test LREM send_command(&mut stream, "*3\r\n$5\r\nLPUSH\r\n$4\r\nlist\r\n$1\r\na\r\n").await; // list is now a, c, a diff --git a/tests/simple_integration_test.rs b/tests/simple_integration_test.rs index 6cc5d63..f6f6abf 100644 --- a/tests/simple_integration_test.rs +++ b/tests/simple_integration_test.rs @@ -22,7 +22,7 @@ async fn start_test_server(test_name: &str) -> (Server, u16) { dir: test_dir, port, debug: true, - max_databases: Some(16), + encrypt: false, encryption_key: None, }; diff --git a/tests/simple_redis_test.rs b/tests/simple_redis_test.rs index 3791779..0114f52 100644 --- a/tests/simple_redis_test.rs +++ b/tests/simple_redis_test.rs @@ -20,7 +20,8 @@ async fn start_test_server(test_name: &str) -> (Server, u16) { dir: test_dir, port, debug: false, - databases: 16, + encrypt: false, + encryption_key: None, }; let server = Server::new(option).await;