From ee94d731d7f4857956bf04fe151aa43ff491d6c8 Mon Sep 17 00:00:00 2001 From: despiegk Date: Sat, 16 Aug 2025 11:09:18 +0200 Subject: [PATCH] ... --- Cargo.lock | 13 + Cargo.toml | 1 + src/cmd.rs | 62 +- src/error.rs | 12 + src/server.rs | 5 +- src/storage.rs | 1261 ---------------------------------- src/storage/mod.rs | 126 ++++ src/storage/storage_basic.rs | 218 ++++++ src/storage/storage_extra.rs | 168 +++++ src/storage/storage_hset.rs | 318 +++++++++ src/storage/storage_lists.rs | 403 +++++++++++ 11 files changed, 1297 insertions(+), 1290 deletions(-) delete mode 100644 src/storage.rs create mode 100644 src/storage/mod.rs create mode 100644 src/storage/storage_basic.rs create mode 100644 src/storage/storage_extra.rs create mode 100644 src/storage/storage_hset.rs create mode 100644 src/storage/storage_lists.rs diff --git a/Cargo.lock b/Cargo.lock index f7bfb53..893fe2d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -799,6 +799,7 @@ dependencies = [ "redb", "redis", "serde", + "serde_json", "sha2", "thiserror", "tokio", @@ -851,6 +852,18 @@ dependencies = [ "syn", ] +[[package]] +name = "serde_json" +version = "1.0.142" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "030fedb782600dcbd6f02d479bf0d817ac3bb40d644745b769d6a96bc3afc5a7" +dependencies = [ + "itoa", + "memchr", + "ryu", + "serde", +] + [[package]] name = "sha1_smol" version = "1.0.1" diff --git a/Cargo.toml b/Cargo.toml index 18dc7ab..4b5e15f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ byteorder = "1.4.3" futures = "0.3" redb = "2.1.3" serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" bincode = "1.3.3" chacha20poly1305 = "0.10.1" rand = "0.8" diff --git a/src/cmd.rs b/src/cmd.rs index fbb844f..a230f98 100644 --- a/src/cmd.rs +++ b/src/cmd.rs @@ -553,41 +553,41 @@ async fn llen_cmd(server: &Server, key: &str) -> Result { } async fn lpop_cmd(server: &Server, key: &str, count: &Option) -> Result { - match server.current_storage()?.lpop(key, *count) { - Ok(Some(elements)) => { - if count.is_some() { + let count_val = count.unwrap_or(1); + match server.current_storage()?.lpop(key, count_val) { + Ok(elements) => { + if elements.is_empty() { + if count.is_some() { + Ok(Protocol::Array(vec![])) + } else { + Ok(Protocol::Null) + } + } else if count.is_some() { Ok(Protocol::Array(elements.into_iter().map(Protocol::BulkString).collect())) } else { Ok(Protocol::BulkString(elements[0].clone())) } }, - Ok(None) => { - if count.is_some() { - Ok(Protocol::Array(vec![])) - } else { - Ok(Protocol::Null) - } - }, Err(e) => Ok(Protocol::err(&e.0)), } } async fn rpop_cmd(server: &Server, key: &str, count: &Option) -> Result { - match server.current_storage()?.rpop(key, *count) { - Ok(Some(elements)) => { - if count.is_some() { + let count_val = count.unwrap_or(1); + match server.current_storage()?.rpop(key, count_val) { + Ok(elements) => { + if elements.is_empty() { + if count.is_some() { + Ok(Protocol::Array(vec![])) + } else { + Ok(Protocol::Null) + } + } else if count.is_some() { Ok(Protocol::Array(elements.into_iter().map(Protocol::BulkString).collect())) } else { Ok(Protocol::BulkString(elements[0].clone())) } }, - Ok(None) => { - if count.is_some() { - Ok(Protocol::Array(vec![])) - } else { - Ok(Protocol::Null) - } - }, Err(e) => Ok(Protocol::err(&e.0)), } } @@ -746,7 +746,7 @@ async fn get_cmd(server: &Server, k: &str) -> Result { // Hash command implementations async fn hset_cmd(server: &Server, key: &str, pairs: &[(String, String)]) -> Result { - let new_fields = server.current_storage()?.hset(key, pairs)?; + let new_fields = server.current_storage()?.hset(key, pairs.to_vec())?; Ok(Protocol::SimpleString(new_fields.to_string())) } @@ -773,7 +773,7 @@ async fn hgetall_cmd(server: &Server, key: &str) -> Result { } async fn hdel_cmd(server: &Server, key: &str, fields: &[String]) -> Result { - match server.current_storage()?.hdel(key, fields) { + match server.current_storage()?.hdel(key, fields.to_vec()) { Ok(deleted) => Ok(Protocol::SimpleString(deleted.to_string())), Err(e) => Ok(Protocol::err(&e.0)), } @@ -812,7 +812,7 @@ async fn hlen_cmd(server: &Server, key: &str) -> Result { } async fn hmget_cmd(server: &Server, key: &str, fields: &[String]) -> Result { - match server.current_storage()?.hmget(key, fields) { + match server.current_storage()?.hmget(key, fields.to_vec()) { Ok(values) => { let result: Vec = values .into_iter() @@ -838,10 +838,12 @@ async fn scan_cmd( count: &Option ) -> Result { match server.current_storage()?.scan(*cursor, pattern, *count) { - Ok((next_cursor, keys)) => { + Ok((next_cursor, key_value_pairs)) => { let mut result = Vec::new(); result.push(Protocol::BulkString(next_cursor.to_string())); - result.push(Protocol::Array(keys.into_iter().map(Protocol::BulkString).collect())); + // For SCAN, we only return the keys, not the values + let keys: Vec = key_value_pairs.into_iter().map(|(key, _)| Protocol::BulkString(key)).collect(); + result.push(Protocol::Array(keys)); Ok(Protocol::Array(result)) } Err(e) => Ok(Protocol::err(&format!("ERR {}", e.0))), @@ -856,10 +858,16 @@ async fn hscan_cmd( count: &Option ) -> Result { match server.current_storage()?.hscan(key, *cursor, pattern, *count) { - Ok((next_cursor, fields)) => { + Ok((next_cursor, field_value_pairs)) => { let mut result = Vec::new(); result.push(Protocol::BulkString(next_cursor.to_string())); - result.push(Protocol::Array(fields.into_iter().map(Protocol::BulkString).collect())); + // For HSCAN, we return field-value pairs flattened + let mut fields_and_values = Vec::new(); + for (field, value) in field_value_pairs { + fields_and_values.push(Protocol::BulkString(field)); + fields_and_values.push(Protocol::BulkString(value)); + } + result.push(Protocol::Array(fields_and_values)); Ok(Protocol::Array(result)) } Err(e) => Ok(Protocol::err(&format!("ERR {}", e.0))), diff --git a/src/error.rs b/src/error.rs index 018e878..3037c70 100644 --- a/src/error.rs +++ b/src/error.rs @@ -80,3 +80,15 @@ impl From> for DBError { DBError(item.to_string().clone()) } } + +impl From for DBError { + fn from(item: serde_json::Error) -> Self { + DBError(item.to_string()) + } +} + +impl From for DBError { + fn from(item: chacha20poly1305::Error) -> Self { + DBError(item.to_string()) + } +} diff --git a/src/server.rs b/src/server.rs index 80f64a1..c286e21 100644 --- a/src/server.rs +++ b/src/server.rs @@ -61,8 +61,9 @@ impl Server { Ok(storage) } - fn should_encrypt_db(&self, _db_index: u64) -> bool { - self.option.encrypt + fn should_encrypt_db(&self, db_index: u64) -> bool { + // DB 0-9 are non-encrypted, DB 10+ are encrypted + self.option.encrypt && db_index >= 10 } pub async fn handle( diff --git a/src/storage.rs b/src/storage.rs deleted file mode 100644 index 083c9f5..0000000 --- a/src/storage.rs +++ /dev/null @@ -1,1261 +0,0 @@ -use std::{ - path::Path, - time::{SystemTime, UNIX_EPOCH}, -}; - -use redb::{Database, ReadableTable, TableDefinition}; -use serde::{Deserialize, Serialize}; - -use crate::crypto::CryptoFactory; -use crate::error::DBError; - -// Add this glob matching function -fn glob_match(pattern: &str, text: &str) -> bool { - fn match_recursive(pattern: &[char], text: &[char], p_idx: usize, t_idx: usize) -> bool { - if p_idx >= pattern.len() { - return t_idx >= text.len(); - } - - match pattern[p_idx] { - '*' => { - // Try matching zero characters - if match_recursive(pattern, text, p_idx + 1, t_idx) { - return true; - } - // Try matching one or more characters - for i in t_idx..text.len() { - if match_recursive(pattern, text, p_idx + 1, i + 1) { - return true; - } - } - false - } - '?' => { - if t_idx >= text.len() { - false - } else { - match_recursive(pattern, text, p_idx + 1, t_idx + 1) - } - } - '[' => { - // Find the closing bracket - let mut bracket_end = p_idx + 1; - while bracket_end < pattern.len() && pattern[bracket_end] != ']' { - bracket_end += 1; - } - if bracket_end >= pattern.len() || t_idx >= text.len() { - return false; - } - - let bracket_content = &pattern[p_idx + 1..bracket_end]; - let char_to_match = text[t_idx]; - let mut matched = false; - - let mut i = 0; - while i < bracket_content.len() { - if i + 2 < bracket_content.len() && bracket_content[i + 1] == '-' { - // Range like [a-z] - if char_to_match >= bracket_content[i] && char_to_match <= bracket_content[i + 2] { - matched = true; - break; - } - i += 3; - } else { - // Single character - if char_to_match == bracket_content[i] { - matched = true; - break; - } - i += 1; - } - } - - if matched { - match_recursive(pattern, text, bracket_end + 1, t_idx + 1) - } else { - false - } - } - '\\' => { - // Escape next character - if p_idx + 1 >= pattern.len() || t_idx >= text.len() { - false - } else if pattern[p_idx + 1] == text[t_idx] { - match_recursive(pattern, text, p_idx + 2, t_idx + 1) - } else { - false - } - } - c => { - if t_idx >= text.len() || c != text[t_idx] { - false - } else { - match_recursive(pattern, text, p_idx + 1, t_idx + 1) - } - } - } - } - - let pattern_chars: Vec = pattern.chars().collect(); - let text_chars: Vec = text.chars().collect(); - match_recursive(&pattern_chars, &text_chars, 0, 0) -} - -// Table definitions for different Redis data types -const TYPES_TABLE: TableDefinition<&str, &str> = TableDefinition::new("types"); -const STRINGS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("strings"); -const HASHES_TABLE: TableDefinition<(&str, &str), &[u8]> = TableDefinition::new("hashes"); -const LISTS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("lists"); -const STREAMS_META_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("streams_meta"); -const STREAMS_DATA_TABLE: TableDefinition<(&str, &str), &[u8]> = TableDefinition::new("streams_data"); -const ENCRYPTED_TABLE: TableDefinition<&str, u8> = TableDefinition::new("encrypted"); -const EXPIRATION_TABLE: TableDefinition<&str, u64> = TableDefinition::new("expiration"); - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct StreamEntry { - pub fields: Vec<(String, String)>, -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct ListValue { - pub elements: Vec, -} - - -#[inline] -pub fn now_in_millis() -> u128 { - let start = SystemTime::now(); - let duration_since_epoch = start.duration_since(UNIX_EPOCH).unwrap(); - duration_since_epoch.as_millis() -} - -pub struct Storage { - db: Database, - crypto: Option, -} - -impl Storage { - pub fn new(path: impl AsRef, should_encrypt: bool, master_key: Option<&str>) -> Result { - let db = Database::create(path)?; - - // Create tables if they don't exist - let write_txn = db.begin_write()?; - { - let _ = write_txn.open_table(TYPES_TABLE)?; - let _ = write_txn.open_table(STRINGS_TABLE)?; - let _ = write_txn.open_table(HASHES_TABLE)?; - let _ = write_txn.open_table(LISTS_TABLE)?; - let _ = write_txn.open_table(STREAMS_META_TABLE)?; - let _ = write_txn.open_table(STREAMS_DATA_TABLE)?; - let _ = write_txn.open_table(ENCRYPTED_TABLE)?; - let _ = write_txn.open_table(EXPIRATION_TABLE)?; - } - write_txn.commit()?; - - // Check if database was previously encrypted - let read_txn = db.begin_read()?; - let encrypted_table = read_txn.open_table(ENCRYPTED_TABLE)?; - let was_encrypted = encrypted_table.get("encrypted")?.map(|v| v.value() == 1).unwrap_or(false); - drop(read_txn); - - let crypto = if should_encrypt || was_encrypted { - if let Some(key) = master_key { - Some(CryptoFactory::new(key.as_bytes())) - } else { - return Err(DBError("Encryption requested but no master key provided".to_string())); - } - } else { - None - }; - - // If we're enabling encryption for the first time, mark it - if should_encrypt && !was_encrypted { - let write_txn = db.begin_write()?; - { - let mut encrypted_table = write_txn.open_table(ENCRYPTED_TABLE)?; - encrypted_table.insert("encrypted", &1u8)?; - } - write_txn.commit()?; - } - - Ok(Storage { - db, - crypto, - }) - } - - pub fn is_encrypted(&self) -> bool { - self.crypto.is_some() - } - - // Helper methods for encryption - fn encrypt_if_needed(&self, data: &[u8]) -> Result, DBError> { - if let Some(crypto) = &self.crypto { - Ok(crypto.encrypt(data)) - } else { - Ok(data.to_vec()) - } - } - - fn decrypt_if_needed(&self, data: &[u8]) -> Result, DBError> { - if let Some(crypto) = &self.crypto { - Ok(crypto.decrypt(data)?) - } else { - Ok(data.to_vec()) - } - } - - pub fn flushdb(&self) -> Result<(), DBError> { - let write_txn = self.db.begin_write()?; - { - let mut types_table = write_txn.open_table(TYPES_TABLE)?; - let mut strings_table = write_txn.open_table(STRINGS_TABLE)?; - let mut hashes_table = write_txn.open_table(HASHES_TABLE)?; - let mut lists_table = write_txn.open_table(LISTS_TABLE)?; - let mut streams_meta_table = write_txn.open_table(STREAMS_META_TABLE)?; - let mut streams_data_table = write_txn.open_table(STREAMS_DATA_TABLE)?; - let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?; - - // inefficient, but there is no other way - let keys: Vec = types_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect(); - for key in keys { - types_table.remove(key.as_str())?; - } - let keys: Vec = strings_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect(); - for key in keys { - strings_table.remove(key.as_str())?; - } - let keys: Vec<(String, String)> = hashes_table - .iter()? - .map(|item| { - let binding = item.unwrap(); - let (k, f) = binding.0.value(); - (k.to_string(), f.to_string()) - }) - .collect(); - for (key, field) in keys { - hashes_table.remove((key.as_str(), field.as_str()))?; - } - let keys: Vec = lists_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect(); - for key in keys { - lists_table.remove(key.as_str())?; - } - let keys: Vec = streams_meta_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect(); - for key in keys { - streams_meta_table.remove(key.as_str())?; - } - let keys: Vec<(String,String)> = streams_data_table.iter()?.map(|item| { - let binding = item.unwrap(); - let (key, field) = binding.0.value(); - (key.to_string(), field.to_string()) - }).collect(); - for (key, field) in keys { - streams_data_table.remove((key.as_str(), field.as_str()))?; - } - let keys: Vec = expiration_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect(); - for key in keys { - expiration_table.remove(key.as_str())?; - } - } - write_txn.commit()?; - Ok(()) - } - - pub fn get_key_type(&self, key: &str) -> Result, DBError> { - let read_txn = self.db.begin_read()?; - let table = read_txn.open_table(TYPES_TABLE)?; - - // Before returning type, check for expiration - if let Some(type_val) = table.get(key)? { - if type_val.value() == "string" { - let expiration_table = read_txn.open_table(EXPIRATION_TABLE)?; - if let Some(expires_at) = expiration_table.get(key)? { - if now_in_millis() > expires_at.value() as u128 { - // The key is expired, so it effectively has no type - return Ok(None); - } - } - } - Ok(Some(type_val.value().to_string())) - } else { - Ok(None) - } - } - - // Update the get method to use decryption - pub fn get(&self, key: &str) -> Result, DBError> { - let read_txn = self.db.begin_read()?; - - let types_table = read_txn.open_table(TYPES_TABLE)?; - match types_table.get(key)? { - Some(type_val) if type_val.value() == "string" => { - // Check expiration first (unencrypted) - let expiration_table = read_txn.open_table(EXPIRATION_TABLE)?; - if let Some(expires_at) = expiration_table.get(key)? { - if now_in_millis() > expires_at.value() as u128 { - drop(read_txn); - self.del(key.to_string())?; - return Ok(None); - } - } - - // Get and decrypt value - let strings_table = read_txn.open_table(STRINGS_TABLE)?; - match strings_table.get(key)? { - Some(data) => { - let decrypted = self.decrypt_if_needed(data.value())?; - let value = String::from_utf8(decrypted)?; - Ok(Some(value)) - } - None => Ok(None), - } - } - _ => Ok(None), - } - } - - // Apply similar encryption/decryption to other methods (setx, hset, lpush, etc.) - // ... (you'll need to update all methods that store/retrieve serialized data) - - // Update the set method to use encryption - pub fn set(&self, key: String, value: String) -> Result<(), DBError> { - let write_txn = self.db.begin_write()?; - - { - let mut types_table = write_txn.open_table(TYPES_TABLE)?; - types_table.insert(key.as_str(), "string")?; - - let mut strings_table = write_txn.open_table(STRINGS_TABLE)?; - // Only encrypt the value, not expiration - let encrypted = self.encrypt_if_needed(value.as_bytes())?; - strings_table.insert(key.as_str(), encrypted.as_slice())?; - - // Remove any existing expiration since this is a regular SET - let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?; - expiration_table.remove(key.as_str())?; - } - - write_txn.commit()?; - Ok(()) - } - - pub fn setx(&self, key: String, value: String, expire_ms: u128) -> Result<(), DBError> { - let write_txn = self.db.begin_write()?; - - { - let mut types_table = write_txn.open_table(TYPES_TABLE)?; - types_table.insert(key.as_str(), "string")?; - - let mut strings_table = write_txn.open_table(STRINGS_TABLE)?; - // Only encrypt the value - let encrypted = self.encrypt_if_needed(value.as_bytes())?; - strings_table.insert(key.as_str(), encrypted.as_slice())?; - - // Store expiration separately (unencrypted) - let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?; - let expires_at = expire_ms + now_in_millis(); - expiration_table.insert(key.as_str(), &(expires_at as u64))?; - } - - write_txn.commit()?; - Ok(()) - } - - pub fn del(&self, key: String) -> Result<(), DBError> { - let write_txn = self.db.begin_write()?; - - { - let mut types_table = write_txn.open_table(TYPES_TABLE)?; - let mut strings_table = write_txn.open_table(STRINGS_TABLE)?; - let mut hashes_table: redb::Table<(&str, &str), &[u8]> = write_txn.open_table(HASHES_TABLE)?; - let mut lists_table = write_txn.open_table(LISTS_TABLE)?; - - // Remove from type table - types_table.remove(key.as_str())?; - - // Remove from strings table - strings_table.remove(key.as_str())?; - - // Remove all hash fields for this key - let mut to_remove = Vec::new(); - let mut iter = hashes_table.iter()?; - while let Some(entry) = iter.next() { - let entry = entry?; - let (hash_key, field) = entry.0.value(); - if hash_key == key.as_str() { - to_remove.push((hash_key.to_string(), field.to_string())); - } - } - drop(iter); - - for (hash_key, field) in to_remove { - hashes_table.remove((hash_key.as_str(), field.as_str()))?; - } - - // Remove from lists table - lists_table.remove(key.as_str())?; - - // Also remove expiration - let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?; - expiration_table.remove(key.as_str())?; - } - - write_txn.commit()?; - Ok(()) - } - - pub fn keys(&self, pattern: &str) -> Result, DBError> { - let read_txn = self.db.begin_read()?; - let table = read_txn.open_table(TYPES_TABLE)?; - - let mut keys = Vec::new(); - let mut iter = table.iter()?; - while let Some(entry) = iter.next() { - let key = entry?.0.value().to_string(); - if pattern == "*" || glob_match(pattern, &key) { - keys.push(key); - } - } - - Ok(keys) - } - - // Hash operations - pub fn hset(&self, key: &str, pairs: &[(String, String)]) -> Result { - let write_txn = self.db.begin_write()?; - let mut new_fields = 0u64; - - { - let mut types_table = write_txn.open_table(TYPES_TABLE)?; - let mut hashes_table = write_txn.open_table(HASHES_TABLE)?; - - // Check if key exists and is of correct type - let existing_type = match types_table.get(key)? { - Some(type_val) => Some(type_val.value().to_string()), - None => None, - }; - - match existing_type { - Some(ref type_str) if type_str != "hash" => { - return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())); - } - None => { - // Set type to hash - types_table.insert(key, "hash")?; - } - _ => {} - } - - for (field, value) in pairs { - let existed = hashes_table.get((key, field.as_str()))?.is_some(); - - // Encrypt the value before storing - let encrypted = self.encrypt_if_needed(value.as_bytes())?; - hashes_table.insert((key, field.as_str()), encrypted.as_slice())?; - - if !existed { - new_fields += 1; - } - } - } - - write_txn.commit()?; - Ok(new_fields) - } - - pub fn hget(&self, key: &str, field: &str) -> Result, DBError> { - let read_txn = self.db.begin_read()?; - - // Check type - let types_table = read_txn.open_table(TYPES_TABLE)?; - match types_table.get(key)? { - Some(type_val) if type_val.value() == "hash" => { - let hashes_table = read_txn.open_table(HASHES_TABLE)?; - match hashes_table.get((key, field))? { - Some(data) => { - let decrypted = self.decrypt_if_needed(data.value())?; - let value = String::from_utf8(decrypted)?; - Ok(Some(value)) - } - None => Ok(None), - } - } - Some(_) => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), - None => Ok(None), - } - } - - pub fn hgetall(&self, key: &str) -> Result, DBError> { - let read_txn = self.db.begin_read()?; - - // Check type - let types_table = read_txn.open_table(TYPES_TABLE)?; - match types_table.get(key)? { - Some(type_val) if type_val.value() == "hash" => { - let hashes_table = read_txn.open_table(HASHES_TABLE)?; - let mut result = Vec::new(); - - let mut iter = hashes_table.iter()?; - while let Some(entry) = iter.next() { - let entry = entry?; - let (hash_key, field) = entry.0.value(); - let value = entry.1.value(); - if hash_key == key { - let decrypted = self.decrypt_if_needed(value)?; - let value_str = String::from_utf8(decrypted)?; - result.push((field.to_string(), value_str)); - } - } - - Ok(result) - } - Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), - None => Ok(Vec::new()), - } - } - - pub fn hdel(&self, key: &str, fields: &[String]) -> Result { - // Enforce type check before proceeding to write transaction - let key_type = self.get_key_type(key)?; - match key_type.as_deref() { - Some("hash") => { - let write_txn = self.db.begin_write()?; - let mut deleted = 0u64; - { - let mut hashes_table = write_txn.open_table(HASHES_TABLE)?; - for field in fields { - if hashes_table.remove((key, field.as_str()))?.is_some() { - deleted += 1; - } - } - } - write_txn.commit()?; - Ok(deleted) - } - Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), - None => Ok(0), // Key doesn't exist, so 0 fields deleted. - } - } - - pub fn hexists(&self, key: &str, field: &str) -> Result { - match self.get_key_type(key)?.as_deref() { - Some("hash") => { - let read_txn = self.db.begin_read()?; - let hashes_table = read_txn.open_table(HASHES_TABLE)?; - Ok(hashes_table.get((key, field))?.is_some()) - } - Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), - None => Ok(false), - } - } - - pub fn hkeys(&self, key: &str) -> Result, DBError> { - match self.get_key_type(key)?.as_deref() { - Some("hash") => { - let read_txn = self.db.begin_read()?; - let hashes_table = read_txn.open_table(HASHES_TABLE)?; - let mut result = Vec::new(); - for entry in hashes_table.range((key, "")..=(key, "\u{FFFF}"))? { - result.push(entry?.0.value().1.to_string()); - } - Ok(result) - } - Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), - None => Ok(Vec::new()), - } - } - - pub fn hvals(&self, key: &str) -> Result, DBError> { - match self.get_key_type(key)?.as_deref() { - Some("hash") => { - let read_txn = self.db.begin_read()?; - let hashes_table = read_txn.open_table(HASHES_TABLE)?; - let mut result = Vec::new(); - for entry in hashes_table.range((key, "")..=(key, "\u{FFFF}"))? { - let value = self.decrypt_if_needed(entry?.1.value())?; - result.push(String::from_utf8(value)?); - } - Ok(result) - } - Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), - None => Ok(Vec::new()), - } - } - - pub fn hlen(&self, key: &str) -> Result { - match self.get_key_type(key)?.as_deref() { - Some("hash") => { - let read_txn = self.db.begin_read()?; - let hashes_table = read_txn.open_table(HASHES_TABLE)?; - // Use `range` for efficiency - Ok(hashes_table.range((key, "")..=(key, "\u{FFFF}"))?.count() as u64) - } - Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), - None => Ok(0), - } - } - - pub fn hmget(&self, key: &str, fields: &[String]) -> Result>, DBError> { - match self.get_key_type(key)?.as_deref() { - Some("hash") => { - let read_txn = self.db.begin_read()?; - let hashes_table = read_txn.open_table(HASHES_TABLE)?; - let mut result = Vec::new(); - for field in fields { - let value = match hashes_table.get((key, field.as_str()))? { - Some(data) => Some(String::from_utf8(self.decrypt_if_needed(data.value())?)?), - None => None, - }; - result.push(value); - } - Ok(result) - } - Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), - None => Ok(vec![None; fields.len()]), - } - } - - pub fn hsetnx(&self, key: &str, field: &str, value: &str) -> Result { - let write_txn = self.db.begin_write()?; - let mut result = false; - - { - let mut types_table = write_txn.open_table(TYPES_TABLE)?; - let mut hashes_table = write_txn.open_table(HASHES_TABLE)?; - - // Check if key exists and is of correct type - let existing_type = match types_table.get(key)? { - Some(type_val) => Some(type_val.value().to_string()), - None => None, - }; - - match existing_type { - Some(ref type_str) if type_str != "hash" => { - return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())); - } - None => { - // Set type to hash - types_table.insert(key, "hash")?; - } - _ => {} - } - - // Check if field already exists - if hashes_table.get((key, field))?.is_none() { - let encrypted_value = self.encrypt_if_needed(value.as_bytes())?; - hashes_table.insert((key, field), encrypted_value.as_slice())?; - result = true; - } - } - - write_txn.commit()?; - Ok(result) - } - - pub fn scan(&self, cursor: u64, pattern: Option<&str>, count: Option) -> Result<(u64, Vec), DBError> { - let read_txn = self.db.begin_read()?; - - // Explicitly specify the table type to avoid confusion - let types_table: redb::ReadOnlyTable<&str, &str> = read_txn.open_table(TYPES_TABLE)?; - - let count = count.unwrap_or(10); // Default count is 10 - let mut keys = Vec::new(); - let mut current_cursor = 0u64; - let mut returned_keys = 0u64; - - let mut iter = types_table.iter()?; - while let Some(entry) = iter.next() { - let entry = entry?; - let key = entry.0.value().to_string(); - - // Skip keys until we reach the cursor position - if current_cursor < cursor { - current_cursor += 1; - continue; - } - - // Check if key matches pattern - let matches = match pattern { - Some(pat) => { - if pat == "*" { - true - } else if pat.contains('*') { - // Use the glob_match function for better pattern matching - glob_match(pat, &key) - } else { - key.contains(pat) - } - } - None => true, - }; - - if matches { - keys.push(key); - returned_keys += 1; - - // Stop if we've returned enough keys - if returned_keys >= count { - break; - } - } - - current_cursor += 1; - } - - // If we've reached the end of the iteration, return cursor 0, otherwise return the next cursor position - let next_cursor = if returned_keys < count { 0 } else { current_cursor }; - - Ok((next_cursor, keys)) - } - - pub fn hscan(&self, key: &str, cursor: u64, pattern: Option<&str>, count: Option) -> Result<(u64, Vec), DBError> { - let read_txn = self.db.begin_read()?; - - // Check if key exists and is a hash - let types_table: redb::ReadOnlyTable<&str, &str> = read_txn.open_table(TYPES_TABLE)?; - match types_table.get(key)? { - Some(type_val) if type_val.value() == "hash" => { - let hashes_table: redb::ReadOnlyTable<(&str, &str), &[u8]> = read_txn.open_table(HASHES_TABLE)?; - let count = count.unwrap_or(10); - let mut fields = Vec::new(); - let mut current_cursor = 0u64; - let mut returned_fields = 0u64; - - let mut iter = hashes_table.iter()?; - while let Some(entry) = iter.next() { - let entry = entry?; - let (hash_key, field) = entry.0.value(); - let value = entry.1.value(); - - if hash_key != key { - continue; - } - - // Skip fields until we reach the cursor position - if current_cursor < cursor { - current_cursor += 1; - continue; - } - - // Check if field matches pattern - let matches = match pattern { - Some(pat) => { - if pat == "*" { - true - } else if pat.contains('*') { - // Use the glob_match function for better pattern matching - glob_match(pat, field) - } else { - field.contains(pat) - } - } - None => true, - }; - - if matches { - let decrypted = self.decrypt_if_needed(value)?; - let value_str = String::from_utf8(decrypted)?; - fields.push(field.to_string()); - fields.push(value_str); - returned_fields += 1; - - if returned_fields >= count { - break; - } - } - - current_cursor += 1; - } - - // Check if there are more entries by trying to get the next one - let next_cursor = if returned_fields < count { 0 } else { current_cursor }; - Ok((next_cursor, fields)) - } - Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), - None => Ok((0, Vec::new())), - } - } - - pub fn ttl(&self, key: &str) -> Result { - let read_txn = self.db.begin_read()?; - - let types_table = read_txn.open_table(TYPES_TABLE)?; - match types_table.get(key)? { - Some(type_val) if type_val.value() == "string" => { - let expiration_table = read_txn.open_table(EXPIRATION_TABLE)?; - match expiration_table.get(key)? { - Some(expires_at) => { - let now = now_in_millis(); - let expires_at = expires_at.value() as u128; - if now > expires_at { - Ok(-2) // Key expired - } else { - Ok(((expires_at - now) / 1000) as i64) // TTL in seconds - } - } - None => Ok(-1), // No expiration - } - } - Some(_) => Ok(-1), // Other types don't have TTL - None => Ok(-2), // Key doesn't exist - } - } - - pub fn exists(&self, key: &str) -> Result { - let read_txn = self.db.begin_read()?; - let types_table = read_txn.open_table(TYPES_TABLE)?; - - match types_table.get(key)? { - Some(type_val) => { - // For string types, check if not expired - if type_val.value() == "string" { - let expiration_table = read_txn.open_table(EXPIRATION_TABLE)?; - if let Some(expires_at) = expiration_table.get(key)? { - if now_in_millis() > expires_at.value() as u128 { - return Ok(false); // Expired - } - } - } - Ok(true) - } - None => Ok(false), - } - } - - // List operations - pub fn lpush(&self, key: &str, elements: Vec) -> Result { - let write_txn = self.db.begin_write()?; - let new_len; - - { - let mut types_table = write_txn.open_table(TYPES_TABLE)?; - let mut lists_table = write_txn.open_table(LISTS_TABLE)?; - - let existing_type = match types_table.get(key)? { - Some(type_val) => Some(type_val.value().to_string()), - None => None, - }; - - match existing_type { - Some(ref type_str) if type_str != "list" => { - return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())); - } - None => { - types_table.insert(key, "list")?; - } - _ => {} - } - - let mut list_value: ListValue = match lists_table.get(key)? { - Some(data) => { - let decrypted = self.decrypt_if_needed(data.value())?; - bincode::deserialize(&decrypted)? - }, - None => ListValue { elements: Vec::new() }, - }; - - for element in elements.into_iter() { - list_value.elements.insert(0, element); - } - new_len = list_value.elements.len() as u64; - - let serialized = bincode::serialize(&list_value)?; - let encrypted = self.encrypt_if_needed(&serialized)?; - lists_table.insert(key, encrypted.as_slice())?; - } - - write_txn.commit()?; - Ok(new_len) - } - - pub fn rpush(&self, key: &str, elements: Vec) -> Result { - let write_txn = self.db.begin_write()?; - let new_len; - - { - let mut types_table = write_txn.open_table(TYPES_TABLE)?; - let mut lists_table = write_txn.open_table(LISTS_TABLE)?; - - let existing_type = match types_table.get(key)? { - Some(type_val) => Some(type_val.value().to_string()), - None => None, - }; - - match existing_type { - Some(ref type_str) if type_str != "list" => { - return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())); - } - None => { - types_table.insert(key, "list")?; - } - _ => {} - } - - let mut list_value: ListValue = match lists_table.get(key)? { - Some(data) => { - let decrypted = self.decrypt_if_needed(data.value())?; - bincode::deserialize(&decrypted)? - }, - None => ListValue { elements: Vec::new() }, - }; - - for element in elements { - list_value.elements.push(element); - } - new_len = list_value.elements.len() as u64; - - let serialized = bincode::serialize(&list_value)?; - let encrypted = self.encrypt_if_needed(&serialized)?; - lists_table.insert(key, encrypted.as_slice())?; - } - - write_txn.commit()?; - Ok(new_len) - } - - pub fn lpop(&self, key: &str, count: Option) -> Result>, DBError> { - let write_txn = self.db.begin_write()?; - let mut result_elements = Vec::new(); - - { - let mut types_table = write_txn.open_table(TYPES_TABLE)?; - let mut lists_table = write_txn.open_table(LISTS_TABLE)?; - - let existing_type = match types_table.get(key)? { - Some(type_val) => Some(type_val.value().to_string()), - None => None, - }; - - match existing_type { - Some(ref type_str) if type_str != "list" => { - return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())); - } - Some(_) => { - let mut list_value: ListValue = match lists_table.get(key)? { - Some(data) => { - let decrypted = self.decrypt_if_needed(data.value())?; - bincode::deserialize(&decrypted)? - }, - None => return Ok(None), // Key exists but list is empty (shouldn't happen if type is "list") - }; - - let num_to_pop = count.unwrap_or(1) as usize; - for _ in 0..num_to_pop { - if !list_value.elements.is_empty() { - result_elements.push(list_value.elements.remove(0)); - } else { - break; - } - } - - if list_value.elements.is_empty() { - lists_table.remove(key)?; - types_table.remove(key)?; - } else { - let serialized = bincode::serialize(&list_value)?; - let encrypted = self.encrypt_if_needed(&serialized)?; - lists_table.insert(key, encrypted.as_slice())?; - } - } - None => return Ok(None), - } - } - - write_txn.commit()?; - if result_elements.is_empty() { - Ok(None) - } else { - Ok(Some(result_elements)) - } - } - - pub fn rpop(&self, key: &str, count: Option) -> Result>, DBError> { - let write_txn = self.db.begin_write()?; - let mut result_elements = Vec::new(); - - { - let mut types_table = write_txn.open_table(TYPES_TABLE)?; - let mut lists_table = write_txn.open_table(LISTS_TABLE)?; - - let existing_type = match types_table.get(key)? { - Some(type_val) => Some(type_val.value().to_string()), - None => None, - }; - - match existing_type { - Some(ref type_str) if type_str != "list" => { - return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())); - } - Some(_) => { - let mut list_value: ListValue = match lists_table.get(key)? { - Some(data) => { - let decrypted = self.decrypt_if_needed(data.value())?; - bincode::deserialize(&decrypted)? - } - None => return Ok(None), - }; - - let num_to_pop = count.unwrap_or(1) as usize; - for _ in 0..num_to_pop { - if let Some(element) = list_value.elements.pop() { - result_elements.push(element); - } else { - break; - } - } - - if list_value.elements.is_empty() { - lists_table.remove(key)?; - types_table.remove(key)?; - } else { - let serialized = bincode::serialize(&list_value)?; - let encrypted = self.encrypt_if_needed(&serialized)?; - lists_table.insert(key, encrypted.as_slice())?; - } - } - None => return Ok(None), - } - } - - write_txn.commit()?; - if result_elements.is_empty() { - Ok(None) - } else { - Ok(Some(result_elements)) - } - } - - pub fn llen(&self, key: &str) -> Result { - let read_txn = self.db.begin_read()?; - - let types_table = read_txn.open_table(TYPES_TABLE)?; - match types_table.get(key)? { - Some(type_val) if type_val.value() == "list" => { - let lists_table = read_txn.open_table(LISTS_TABLE)?; - match lists_table.get(key)? { - Some(data) => { - let decrypted = self.decrypt_if_needed(data.value())?; - let list_value: ListValue = bincode::deserialize(&decrypted)?; - Ok(list_value.elements.len() as u64) - } - None => Ok(0), // Key exists but list is empty - } - } - Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), - None => Ok(0), // Key does not exist - } - } - - pub fn lrem(&self, key: &str, count: i64, element: &str) -> Result { - let write_txn = self.db.begin_write()?; - let removed_count; - - { - let mut types_table = write_txn.open_table(TYPES_TABLE)?; - let mut lists_table = write_txn.open_table(LISTS_TABLE)?; - - let existing_type = match types_table.get(key)? { - Some(type_val) => Some(type_val.value().to_string()), - None => None, - }; - - match existing_type { - Some(ref type_str) if type_str != "list" => { - return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())); - } - Some(_) => { - let mut list_value: ListValue = match lists_table.get(key)? { - Some(data) => { - let decrypted = self.decrypt_if_needed(data.value())?; - bincode::deserialize(&decrypted)? - } - None => return Ok(0), - }; - - let initial_len = list_value.elements.len(); - - if count > 0 { - let mut i = 0; - let mut removed = 0; - while i < list_value.elements.len() && removed < count { - if list_value.elements[i] == element { - list_value.elements.remove(i); - removed += 1; - } else { - i += 1; - } - } - } else if count < 0 { - let mut i = list_value.elements.len() as i32 - 1; - let mut removed = 0; - while i >= 0 && removed < -count { - if list_value.elements[i as usize] == element { - list_value.elements.remove(i as usize); - removed += 1; - } - i -= 1; - } - } else { // count == 0 - list_value.elements.retain(|el| el != element); - } - - removed_count = (initial_len - list_value.elements.len()) as u64; - - if list_value.elements.is_empty() { - lists_table.remove(key)?; - types_table.remove(key)?; - } else { - let serialized = bincode::serialize(&list_value)?; - let encrypted = self.encrypt_if_needed(&serialized)?; - lists_table.insert(key, encrypted.as_slice())?; - } - } - None => return Ok(0), - } - } - - write_txn.commit()?; - Ok(removed_count) - } - - pub fn ltrim(&self, key: &str, start: i64, stop: i64) -> Result<(), DBError> { - let write_txn = self.db.begin_write()?; - - { - let mut types_table = write_txn.open_table(TYPES_TABLE)?; - let mut lists_table = write_txn.open_table(LISTS_TABLE)?; - - let existing_type = match types_table.get(key)? { - Some(type_val) => Some(type_val.value().to_string()), - None => None, - }; - - match existing_type { - Some(ref type_str) if type_str != "list" => { - return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())); - } - Some(_) => { - let mut list_value: ListValue = match lists_table.get(key)? { - Some(data) => { - let decrypted = self.decrypt_if_needed(data.value())?; - bincode::deserialize(&decrypted)? - } - None => return Ok(()), - }; - - let len = list_value.elements.len() as i64; - let mut start = start; - let mut stop = stop; - - if start < 0 { - start += len; - } - if stop < 0 { - stop += len; - } - - if start < 0 { - start = 0; - } - - if start > stop || start >= len { - list_value.elements.clear(); - } else { - if stop >= len { - stop = len - 1; - } - let start = start as usize; - let stop = stop as usize; - list_value.elements = list_value.elements.drain(start..=stop).collect(); - } - - if list_value.elements.is_empty() { - lists_table.remove(key)?; - types_table.remove(key)?; - } else { - let serialized = bincode::serialize(&list_value)?; - let encrypted = self.encrypt_if_needed(&serialized)?; - lists_table.insert(key, encrypted.as_slice())?; - } - } - None => {} - } - } - - write_txn.commit()?; - Ok(()) - } - - pub fn lindex(&self, key: &str, index: i64) -> Result, DBError> { - let read_txn = self.db.begin_read()?; - - let types_table = read_txn.open_table(TYPES_TABLE)?; - match types_table.get(key)? { - Some(type_val) if type_val.value() == "list" => { - let lists_table = read_txn.open_table(LISTS_TABLE)?; - match lists_table.get(key)? { - Some(data) => { - let decrypted = self.decrypt_if_needed(data.value())?; - let list_value: ListValue = bincode::deserialize(&decrypted)?; - let len = list_value.elements.len() as i64; - let mut index = index; - if index < 0 { - index += len; - } - if index < 0 || index >= len { - Ok(None) - } else { - Ok(list_value.elements.get(index as usize).cloned()) - } - } - None => Ok(None), - } - } - Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), - None => Ok(None), - } - } - - pub fn lrange(&self, key: &str, start: i64, stop: i64) -> Result, DBError> { - let read_txn = self.db.begin_read()?; - - let types_table = read_txn.open_table(TYPES_TABLE)?; - match types_table.get(key)? { - Some(type_val) if type_val.value() == "list" => { - let lists_table = read_txn.open_table(LISTS_TABLE)?; - match lists_table.get(key)? { - Some(data) => { - let decrypted = self.decrypt_if_needed(data.value())?; - let list_value: ListValue = bincode::deserialize(&decrypted)?; - let len = list_value.elements.len() as i64; - let mut start = start; - let mut stop = stop; - - if start < 0 { - start += len; - } - if stop < 0 { - stop += len; - } - - if start < 0 { - start = 0; - } - - if start > stop || start >= len { - Ok(Vec::new()) - } else { - if stop >= len { - stop = len - 1; - } - Ok(list_value.elements[start as usize..=stop as usize].to_vec()) - } - } - None => Ok(Vec::new()), - } - } - Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), - None => Ok(Vec::new()), - } - } -} \ No newline at end of file diff --git a/src/storage/mod.rs b/src/storage/mod.rs new file mode 100644 index 0000000..7c33028 --- /dev/null +++ b/src/storage/mod.rs @@ -0,0 +1,126 @@ +use std::{ + path::Path, + time::{SystemTime, UNIX_EPOCH}, +}; + +use redb::{Database, TableDefinition}; +use serde::{Deserialize, Serialize}; + +use crate::crypto::CryptoFactory; +use crate::error::DBError; + +// Re-export modules +mod storage_basic; +mod storage_hset; +mod storage_lists; +mod storage_extra; + +// Re-export implementations +// Note: These imports are used by the impl blocks in the submodules +// The compiler shows them as unused because they're not directly used in this file +// but they're needed for the Storage struct methods to be available +pub use storage_extra::*; + +// Table definitions for different Redis data types +const TYPES_TABLE: TableDefinition<&str, &str> = TableDefinition::new("types"); +const STRINGS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("strings"); +const HASHES_TABLE: TableDefinition<(&str, &str), &[u8]> = TableDefinition::new("hashes"); +const LISTS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("lists"); +const STREAMS_META_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("streams_meta"); +const STREAMS_DATA_TABLE: TableDefinition<(&str, &str), &[u8]> = TableDefinition::new("streams_data"); +const ENCRYPTED_TABLE: TableDefinition<&str, u8> = TableDefinition::new("encrypted"); +const EXPIRATION_TABLE: TableDefinition<&str, u64> = TableDefinition::new("expiration"); + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct StreamEntry { + pub fields: Vec<(String, String)>, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct ListValue { + pub elements: Vec, +} + +#[inline] +pub fn now_in_millis() -> u128 { + let start = SystemTime::now(); + let duration_since_epoch = start.duration_since(UNIX_EPOCH).unwrap(); + duration_since_epoch.as_millis() +} + +pub struct Storage { + db: Database, + crypto: Option, +} + +impl Storage { + pub fn new(path: impl AsRef, should_encrypt: bool, master_key: Option<&str>) -> Result { + let db = Database::create(path)?; + + // Create tables if they don't exist + let write_txn = db.begin_write()?; + { + let _ = write_txn.open_table(TYPES_TABLE)?; + let _ = write_txn.open_table(STRINGS_TABLE)?; + let _ = write_txn.open_table(HASHES_TABLE)?; + let _ = write_txn.open_table(LISTS_TABLE)?; + let _ = write_txn.open_table(STREAMS_META_TABLE)?; + let _ = write_txn.open_table(STREAMS_DATA_TABLE)?; + let _ = write_txn.open_table(ENCRYPTED_TABLE)?; + let _ = write_txn.open_table(EXPIRATION_TABLE)?; + } + write_txn.commit()?; + + // Check if database was previously encrypted + let read_txn = db.begin_read()?; + let encrypted_table = read_txn.open_table(ENCRYPTED_TABLE)?; + let was_encrypted = encrypted_table.get("encrypted")?.map(|v| v.value() == 1).unwrap_or(false); + drop(read_txn); + + let crypto = if should_encrypt || was_encrypted { + if let Some(key) = master_key { + Some(CryptoFactory::new(key.as_bytes())) + } else { + return Err(DBError("Encryption requested but no master key provided".to_string())); + } + } else { + None + }; + + // If we're enabling encryption for the first time, mark it + if should_encrypt && !was_encrypted { + let write_txn = db.begin_write()?; + { + let mut encrypted_table = write_txn.open_table(ENCRYPTED_TABLE)?; + encrypted_table.insert("encrypted", &1u8)?; + } + write_txn.commit()?; + } + + Ok(Storage { + db, + crypto, + }) + } + + pub fn is_encrypted(&self) -> bool { + self.crypto.is_some() + } + + // Helper methods for encryption + fn encrypt_if_needed(&self, data: &[u8]) -> Result, DBError> { + if let Some(crypto) = &self.crypto { + Ok(crypto.encrypt(data)) + } else { + Ok(data.to_vec()) + } + } + + fn decrypt_if_needed(&self, data: &[u8]) -> Result, DBError> { + if let Some(crypto) = &self.crypto { + Ok(crypto.decrypt(data)?) + } else { + Ok(data.to_vec()) + } + } +} \ No newline at end of file diff --git a/src/storage/storage_basic.rs b/src/storage/storage_basic.rs new file mode 100644 index 0000000..a394cb7 --- /dev/null +++ b/src/storage/storage_basic.rs @@ -0,0 +1,218 @@ +use redb::{ReadableTable}; +use crate::error::DBError; +use super::*; + +impl Storage { + pub fn flushdb(&self) -> Result<(), DBError> { + let write_txn = self.db.begin_write()?; + { + let mut types_table = write_txn.open_table(TYPES_TABLE)?; + let mut strings_table = write_txn.open_table(STRINGS_TABLE)?; + let mut hashes_table = write_txn.open_table(HASHES_TABLE)?; + let mut lists_table = write_txn.open_table(LISTS_TABLE)?; + let mut streams_meta_table = write_txn.open_table(STREAMS_META_TABLE)?; + let mut streams_data_table = write_txn.open_table(STREAMS_DATA_TABLE)?; + let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?; + + // inefficient, but there is no other way + let keys: Vec = types_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect(); + for key in keys { + types_table.remove(key.as_str())?; + } + let keys: Vec = strings_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect(); + for key in keys { + strings_table.remove(key.as_str())?; + } + let keys: Vec<(String, String)> = hashes_table + .iter()? + .map(|item| { + let binding = item.unwrap(); + let (k, f) = binding.0.value(); + (k.to_string(), f.to_string()) + }) + .collect(); + for (key, field) in keys { + hashes_table.remove((key.as_str(), field.as_str()))?; + } + let keys: Vec = lists_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect(); + for key in keys { + lists_table.remove(key.as_str())?; + } + let keys: Vec = streams_meta_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect(); + for key in keys { + streams_meta_table.remove(key.as_str())?; + } + let keys: Vec<(String,String)> = streams_data_table.iter()?.map(|item| { + let binding = item.unwrap(); + let (key, field) = binding.0.value(); + (key.to_string(), field.to_string()) + }).collect(); + for (key, field) in keys { + streams_data_table.remove((key.as_str(), field.as_str()))?; + } + let keys: Vec = expiration_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect(); + for key in keys { + expiration_table.remove(key.as_str())?; + } + } + write_txn.commit()?; + Ok(()) + } + + pub fn get_key_type(&self, key: &str) -> Result, DBError> { + let read_txn = self.db.begin_read()?; + let table = read_txn.open_table(TYPES_TABLE)?; + + // Before returning type, check for expiration + if let Some(type_val) = table.get(key)? { + if type_val.value() == "string" { + let expiration_table = read_txn.open_table(EXPIRATION_TABLE)?; + if let Some(expires_at) = expiration_table.get(key)? { + if now_in_millis() > expires_at.value() as u128 { + // The key is expired, so it effectively has no type + return Ok(None); + } + } + } + Ok(Some(type_val.value().to_string())) + } else { + Ok(None) + } + } + + // ✅ ENCRYPTION APPLIED: Value is encrypted/decrypted + pub fn get(&self, key: &str) -> Result, DBError> { + let read_txn = self.db.begin_read()?; + + let types_table = read_txn.open_table(TYPES_TABLE)?; + match types_table.get(key)? { + Some(type_val) if type_val.value() == "string" => { + // Check expiration first (unencrypted) + let expiration_table = read_txn.open_table(EXPIRATION_TABLE)?; + if let Some(expires_at) = expiration_table.get(key)? { + if now_in_millis() > expires_at.value() as u128 { + drop(read_txn); + self.del(key.to_string())?; + return Ok(None); + } + } + + // Get and decrypt value + let strings_table = read_txn.open_table(STRINGS_TABLE)?; + match strings_table.get(key)? { + Some(data) => { + let decrypted = self.decrypt_if_needed(data.value())?; + let value = String::from_utf8(decrypted)?; + Ok(Some(value)) + } + None => Ok(None), + } + } + _ => Ok(None), + } + } + + // ✅ ENCRYPTION APPLIED: Value is encrypted before storage + pub fn set(&self, key: String, value: String) -> Result<(), DBError> { + let write_txn = self.db.begin_write()?; + + { + let mut types_table = write_txn.open_table(TYPES_TABLE)?; + types_table.insert(key.as_str(), "string")?; + + let mut strings_table = write_txn.open_table(STRINGS_TABLE)?; + // Only encrypt the value, not expiration + let encrypted = self.encrypt_if_needed(value.as_bytes())?; + strings_table.insert(key.as_str(), encrypted.as_slice())?; + + // Remove any existing expiration since this is a regular SET + let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?; + expiration_table.remove(key.as_str())?; + } + + write_txn.commit()?; + Ok(()) + } + + // ✅ ENCRYPTION APPLIED: Value is encrypted before storage + pub fn setx(&self, key: String, value: String, expire_ms: u128) -> Result<(), DBError> { + let write_txn = self.db.begin_write()?; + + { + let mut types_table = write_txn.open_table(TYPES_TABLE)?; + types_table.insert(key.as_str(), "string")?; + + let mut strings_table = write_txn.open_table(STRINGS_TABLE)?; + // Only encrypt the value + let encrypted = self.encrypt_if_needed(value.as_bytes())?; + strings_table.insert(key.as_str(), encrypted.as_slice())?; + + // Store expiration separately (unencrypted) + let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?; + let expires_at = expire_ms + now_in_millis(); + expiration_table.insert(key.as_str(), &(expires_at as u64))?; + } + + write_txn.commit()?; + Ok(()) + } + + pub fn del(&self, key: String) -> Result<(), DBError> { + let write_txn = self.db.begin_write()?; + + { + let mut types_table = write_txn.open_table(TYPES_TABLE)?; + let mut strings_table = write_txn.open_table(STRINGS_TABLE)?; + let mut hashes_table: redb::Table<(&str, &str), &[u8]> = write_txn.open_table(HASHES_TABLE)?; + let mut lists_table = write_txn.open_table(LISTS_TABLE)?; + + // Remove from type table + types_table.remove(key.as_str())?; + + // Remove from strings table + strings_table.remove(key.as_str())?; + + // Remove all hash fields for this key + let mut to_remove = Vec::new(); + let mut iter = hashes_table.iter()?; + while let Some(entry) = iter.next() { + let entry = entry?; + let (hash_key, field) = entry.0.value(); + if hash_key == key.as_str() { + to_remove.push((hash_key.to_string(), field.to_string())); + } + } + drop(iter); + + for (hash_key, field) in to_remove { + hashes_table.remove((hash_key.as_str(), field.as_str()))?; + } + + // Remove from lists table + lists_table.remove(key.as_str())?; + + // Also remove expiration + let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?; + expiration_table.remove(key.as_str())?; + } + + write_txn.commit()?; + Ok(()) + } + + pub fn keys(&self, pattern: &str) -> Result, DBError> { + let read_txn = self.db.begin_read()?; + let table = read_txn.open_table(TYPES_TABLE)?; + + let mut keys = Vec::new(); + let mut iter = table.iter()?; + while let Some(entry) = iter.next() { + let key = entry?.0.value().to_string(); + if pattern == "*" || super::storage_extra::glob_match(pattern, &key) { + keys.push(key); + } + } + + Ok(keys) + } +} \ No newline at end of file diff --git a/src/storage/storage_extra.rs b/src/storage/storage_extra.rs new file mode 100644 index 0000000..cb8aa25 --- /dev/null +++ b/src/storage/storage_extra.rs @@ -0,0 +1,168 @@ +use redb::{ReadableTable}; +use crate::error::DBError; +use super::*; + +impl Storage { + // ✅ ENCRYPTION APPLIED: Values are decrypted after retrieval + pub fn scan(&self, cursor: u64, pattern: Option<&str>, count: Option) -> Result<(u64, Vec<(String, String)>), DBError> { + let read_txn = self.db.begin_read()?; + let types_table = read_txn.open_table(TYPES_TABLE)?; + let strings_table = read_txn.open_table(STRINGS_TABLE)?; + + let mut result = Vec::new(); + let mut current_cursor = 0u64; + let limit = count.unwrap_or(10) as usize; + + let mut iter = types_table.iter()?; + while let Some(entry) = iter.next() { + let entry = entry?; + let key = entry.0.value().to_string(); + let key_type = entry.1.value().to_string(); + + if current_cursor >= cursor { + // Apply pattern matching if specified + let matches = if let Some(pat) = pattern { + glob_match(pat, &key) + } else { + true + }; + + if matches { + // For scan, we return key-value pairs for string types + if key_type == "string" { + if let Some(data) = strings_table.get(key.as_str())? { + let decrypted = self.decrypt_if_needed(data.value())?; + let value = String::from_utf8(decrypted)?; + result.push((key, value)); + } else { + result.push((key, String::new())); + } + } else { + // For non-string types, just return the key with type as value + result.push((key, key_type)); + } + + if result.len() >= limit { + break; + } + } + } + current_cursor += 1; + } + + let next_cursor = if result.len() < limit { 0 } else { current_cursor }; + Ok((next_cursor, result)) + } + + pub fn ttl(&self, key: &str) -> Result { + let read_txn = self.db.begin_read()?; + let types_table = read_txn.open_table(TYPES_TABLE)?; + + match types_table.get(key)? { + Some(type_val) if type_val.value() == "string" => { + let expiration_table = read_txn.open_table(EXPIRATION_TABLE)?; + match expiration_table.get(key)? { + Some(expires_at) => { + let now = now_in_millis(); + let expires_at_ms = expires_at.value() as u128; + if now >= expires_at_ms { + Ok(-2) // Key has expired + } else { + Ok(((expires_at_ms - now) / 1000) as i64) // TTL in seconds + } + } + None => Ok(-1), // Key exists but has no expiration + } + } + Some(_) => Ok(-1), // Key exists but is not a string (no expiration support for other types) + None => Ok(-2), // Key does not exist + } + } + + pub fn exists(&self, key: &str) -> Result { + let read_txn = self.db.begin_read()?; + let types_table = read_txn.open_table(TYPES_TABLE)?; + + match types_table.get(key)? { + Some(type_val) if type_val.value() == "string" => { + // Check if string key has expired + let expiration_table = read_txn.open_table(EXPIRATION_TABLE)?; + if let Some(expires_at) = expiration_table.get(key)? { + if now_in_millis() > expires_at.value() as u128 { + return Ok(false); // Key has expired + } + } + Ok(true) + } + Some(_) => Ok(true), // Key exists and is not a string + None => Ok(false), // Key does not exist + } + } +} + +// Utility function for glob pattern matching +pub fn glob_match(pattern: &str, text: &str) -> bool { + if pattern == "*" { + return true; + } + + // Simple glob matching - supports * and ? wildcards + let pattern_chars: Vec = pattern.chars().collect(); + let text_chars: Vec = text.chars().collect(); + + fn match_recursive(pattern: &[char], text: &[char], pi: usize, ti: usize) -> bool { + if pi >= pattern.len() { + return ti >= text.len(); + } + + if ti >= text.len() { + // Check if remaining pattern is all '*' + return pattern[pi..].iter().all(|&c| c == '*'); + } + + match pattern[pi] { + '*' => { + // Try matching zero or more characters + for i in ti..=text.len() { + if match_recursive(pattern, text, pi + 1, i) { + return true; + } + } + false + } + '?' => { + // Match exactly one character + match_recursive(pattern, text, pi + 1, ti + 1) + } + c => { + // Match exact character + if text[ti] == c { + match_recursive(pattern, text, pi + 1, ti + 1) + } else { + false + } + } + } + } + + match_recursive(&pattern_chars, &text_chars, 0, 0) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_glob_match() { + assert!(glob_match("*", "anything")); + assert!(glob_match("hello", "hello")); + assert!(!glob_match("hello", "world")); + assert!(glob_match("h*o", "hello")); + assert!(glob_match("h*o", "ho")); + assert!(!glob_match("h*o", "hi")); + assert!(glob_match("h?llo", "hello")); + assert!(!glob_match("h?llo", "hllo")); + assert!(glob_match("*test*", "this_is_a_test_string")); + assert!(!glob_match("*test*", "this_is_a_string")); + } +} \ No newline at end of file diff --git a/src/storage/storage_hset.rs b/src/storage/storage_hset.rs new file mode 100644 index 0000000..b9ae3c4 --- /dev/null +++ b/src/storage/storage_hset.rs @@ -0,0 +1,318 @@ +use redb::{ReadableTable}; +use crate::error::DBError; +use super::*; + +impl Storage { + // ✅ ENCRYPTION APPLIED: Values are encrypted before storage + pub fn hset(&self, key: &str, pairs: Vec<(String, String)>) -> Result { + let write_txn = self.db.begin_write()?; + let mut new_fields = 0i64; + + { + let mut types_table = write_txn.open_table(TYPES_TABLE)?; + let mut hashes_table = write_txn.open_table(HASHES_TABLE)?; + + // Set the type to hash + types_table.insert(key, "hash")?; + + for (field, value) in pairs { + // Check if field already exists + let exists = hashes_table.get((key, field.as_str()))?.is_some(); + + // Encrypt the value before storing + let encrypted = self.encrypt_if_needed(value.as_bytes())?; + hashes_table.insert((key, field.as_str()), encrypted.as_slice())?; + + if !exists { + new_fields += 1; + } + } + } + + write_txn.commit()?; + Ok(new_fields) + } + + // ✅ ENCRYPTION APPLIED: Value is decrypted after retrieval + pub fn hget(&self, key: &str, field: &str) -> Result, DBError> { + let read_txn = self.db.begin_read()?; + let types_table = read_txn.open_table(TYPES_TABLE)?; + + match types_table.get(key)? { + Some(type_val) if type_val.value() == "hash" => { + let hashes_table = read_txn.open_table(HASHES_TABLE)?; + match hashes_table.get((key, field))? { + Some(data) => { + let decrypted = self.decrypt_if_needed(data.value())?; + let value = String::from_utf8(decrypted)?; + Ok(Some(value)) + } + None => Ok(None), + } + } + _ => Ok(None), + } + } + + // ✅ ENCRYPTION APPLIED: All values are decrypted after retrieval + pub fn hgetall(&self, key: &str) -> Result, DBError> { + let read_txn = self.db.begin_read()?; + let types_table = read_txn.open_table(TYPES_TABLE)?; + + match types_table.get(key)? { + Some(type_val) if type_val.value() == "hash" => { + let hashes_table = read_txn.open_table(HASHES_TABLE)?; + let mut result = Vec::new(); + + let mut iter = hashes_table.iter()?; + while let Some(entry) = iter.next() { + let entry = entry?; + let (hash_key, field) = entry.0.value(); + if hash_key == key { + let decrypted = self.decrypt_if_needed(entry.1.value())?; + let value = String::from_utf8(decrypted)?; + result.push((field.to_string(), value)); + } + } + + Ok(result) + } + _ => Ok(Vec::new()), + } + } + + pub fn hdel(&self, key: &str, fields: Vec) -> Result { + let write_txn = self.db.begin_write()?; + let mut deleted = 0i64; + + // First check if key exists and is a hash + let is_hash = { + let types_table = write_txn.open_table(TYPES_TABLE)?; + let result = match types_table.get(key)? { + Some(type_val) => type_val.value() == "hash", + None => false, + }; + result + }; + + if is_hash { + let mut hashes_table = write_txn.open_table(HASHES_TABLE)?; + + for field in fields { + if hashes_table.remove((key, field.as_str()))?.is_some() { + deleted += 1; + } + } + + // Check if hash is now empty and remove type if so + let mut has_fields = false; + let mut iter = hashes_table.iter()?; + while let Some(entry) = iter.next() { + let entry = entry?; + let (hash_key, _) = entry.0.value(); + if hash_key == key { + has_fields = true; + break; + } + } + drop(iter); + + if !has_fields { + let mut types_table = write_txn.open_table(TYPES_TABLE)?; + types_table.remove(key)?; + } + } + + write_txn.commit()?; + Ok(deleted) + } + + pub fn hexists(&self, key: &str, field: &str) -> Result { + let read_txn = self.db.begin_read()?; + let types_table = read_txn.open_table(TYPES_TABLE)?; + + match types_table.get(key)? { + Some(type_val) if type_val.value() == "hash" => { + let hashes_table = read_txn.open_table(HASHES_TABLE)?; + Ok(hashes_table.get((key, field))?.is_some()) + } + _ => Ok(false), + } + } + + pub fn hkeys(&self, key: &str) -> Result, DBError> { + let read_txn = self.db.begin_read()?; + let types_table = read_txn.open_table(TYPES_TABLE)?; + + match types_table.get(key)? { + Some(type_val) if type_val.value() == "hash" => { + let hashes_table = read_txn.open_table(HASHES_TABLE)?; + let mut result = Vec::new(); + + let mut iter = hashes_table.iter()?; + while let Some(entry) = iter.next() { + let entry = entry?; + let (hash_key, field) = entry.0.value(); + if hash_key == key { + result.push(field.to_string()); + } + } + + Ok(result) + } + _ => Ok(Vec::new()), + } + } + + // ✅ ENCRYPTION APPLIED: All values are decrypted after retrieval + pub fn hvals(&self, key: &str) -> Result, DBError> { + let read_txn = self.db.begin_read()?; + let types_table = read_txn.open_table(TYPES_TABLE)?; + + match types_table.get(key)? { + Some(type_val) if type_val.value() == "hash" => { + let hashes_table = read_txn.open_table(HASHES_TABLE)?; + let mut result = Vec::new(); + + let mut iter = hashes_table.iter()?; + while let Some(entry) = iter.next() { + let entry = entry?; + let (hash_key, _) = entry.0.value(); + if hash_key == key { + let decrypted = self.decrypt_if_needed(entry.1.value())?; + let value = String::from_utf8(decrypted)?; + result.push(value); + } + } + + Ok(result) + } + _ => Ok(Vec::new()), + } + } + + pub fn hlen(&self, key: &str) -> Result { + let read_txn = self.db.begin_read()?; + let types_table = read_txn.open_table(TYPES_TABLE)?; + + match types_table.get(key)? { + Some(type_val) if type_val.value() == "hash" => { + let hashes_table = read_txn.open_table(HASHES_TABLE)?; + let mut count = 0i64; + + let mut iter = hashes_table.iter()?; + while let Some(entry) = iter.next() { + let entry = entry?; + let (hash_key, _) = entry.0.value(); + if hash_key == key { + count += 1; + } + } + + Ok(count) + } + _ => Ok(0), + } + } + + // ✅ ENCRYPTION APPLIED: Values are decrypted after retrieval + pub fn hmget(&self, key: &str, fields: Vec) -> Result>, DBError> { + let read_txn = self.db.begin_read()?; + let types_table = read_txn.open_table(TYPES_TABLE)?; + + match types_table.get(key)? { + Some(type_val) if type_val.value() == "hash" => { + let hashes_table = read_txn.open_table(HASHES_TABLE)?; + let mut result = Vec::new(); + + for field in fields { + match hashes_table.get((key, field.as_str()))? { + Some(data) => { + let decrypted = self.decrypt_if_needed(data.value())?; + let value = String::from_utf8(decrypted)?; + result.push(Some(value)); + } + None => result.push(None), + } + } + + Ok(result) + } + _ => Ok(fields.into_iter().map(|_| None).collect()), + } + } + + // ✅ ENCRYPTION APPLIED: Value is encrypted before storage + pub fn hsetnx(&self, key: &str, field: &str, value: &str) -> Result { + let write_txn = self.db.begin_write()?; + let mut result = false; + + { + let mut types_table = write_txn.open_table(TYPES_TABLE)?; + let mut hashes_table = write_txn.open_table(HASHES_TABLE)?; + + // Check if field already exists + if hashes_table.get((key, field))?.is_none() { + // Set the type to hash + types_table.insert(key, "hash")?; + + // Encrypt the value before storing + let encrypted = self.encrypt_if_needed(value.as_bytes())?; + hashes_table.insert((key, field), encrypted.as_slice())?; + result = true; + } + } + + write_txn.commit()?; + Ok(result) + } + + // ✅ ENCRYPTION APPLIED: Values are decrypted after retrieval + pub fn hscan(&self, key: &str, cursor: u64, pattern: Option<&str>, count: Option) -> Result<(u64, Vec<(String, String)>), DBError> { + let read_txn = self.db.begin_read()?; + let types_table = read_txn.open_table(TYPES_TABLE)?; + + match types_table.get(key)? { + Some(type_val) if type_val.value() == "hash" => { + let hashes_table = read_txn.open_table(HASHES_TABLE)?; + let mut result = Vec::new(); + let mut current_cursor = 0u64; + let limit = count.unwrap_or(10) as usize; + + let mut iter = hashes_table.iter()?; + while let Some(entry) = iter.next() { + let entry = entry?; + let (hash_key, field) = entry.0.value(); + + if hash_key == key { + if current_cursor >= cursor { + let field_str = field.to_string(); + + // Apply pattern matching if specified + let matches = if let Some(pat) = pattern { + super::storage_extra::glob_match(pat, &field_str) + } else { + true + }; + + if matches { + let decrypted = self.decrypt_if_needed(entry.1.value())?; + let value = String::from_utf8(decrypted)?; + result.push((field_str, value)); + + if result.len() >= limit { + break; + } + } + } + current_cursor += 1; + } + } + + let next_cursor = if result.len() < limit { 0 } else { current_cursor }; + Ok((next_cursor, result)) + } + _ => Ok((0, Vec::new())), + } + } +} \ No newline at end of file diff --git a/src/storage/storage_lists.rs b/src/storage/storage_lists.rs new file mode 100644 index 0000000..6dfd381 --- /dev/null +++ b/src/storage/storage_lists.rs @@ -0,0 +1,403 @@ +use redb::{ReadableTable}; +use crate::error::DBError; +use super::*; + +impl Storage { + // ✅ ENCRYPTION APPLIED: Elements are encrypted before storage + pub fn lpush(&self, key: &str, elements: Vec) -> Result { + let write_txn = self.db.begin_write()?; + let mut _length = 0i64; + + { + let mut types_table = write_txn.open_table(TYPES_TABLE)?; + let mut lists_table = write_txn.open_table(LISTS_TABLE)?; + + // Set the type to list + types_table.insert(key, "list")?; + + // Get current list or create empty one + let mut list: Vec = match lists_table.get(key)? { + Some(data) => { + let decrypted = self.decrypt_if_needed(data.value())?; + serde_json::from_slice(&decrypted)? + } + None => Vec::new(), + }; + + // Add elements to the front (left) + for element in elements.into_iter().rev() { + list.insert(0, element); + } + + _length = list.len() as i64; + + // Encrypt and store the updated list + let serialized = serde_json::to_vec(&list)?; + let encrypted = self.encrypt_if_needed(&serialized)?; + lists_table.insert(key, encrypted.as_slice())?; + } + + write_txn.commit()?; + Ok(_length) + } + + // ✅ ENCRYPTION APPLIED: Elements are encrypted before storage + pub fn rpush(&self, key: &str, elements: Vec) -> Result { + let write_txn = self.db.begin_write()?; + let mut _length = 0i64; + + { + let mut types_table = write_txn.open_table(TYPES_TABLE)?; + let mut lists_table = write_txn.open_table(LISTS_TABLE)?; + + // Set the type to list + types_table.insert(key, "list")?; + + // Get current list or create empty one + let mut list: Vec = match lists_table.get(key)? { + Some(data) => { + let decrypted = self.decrypt_if_needed(data.value())?; + serde_json::from_slice(&decrypted)? + } + None => Vec::new(), + }; + + // Add elements to the end (right) + list.extend(elements); + _length = list.len() as i64; + + // Encrypt and store the updated list + let serialized = serde_json::to_vec(&list)?; + let encrypted = self.encrypt_if_needed(&serialized)?; + lists_table.insert(key, encrypted.as_slice())?; + } + + write_txn.commit()?; + Ok(_length) + } + + // ✅ ENCRYPTION APPLIED: Elements are decrypted after retrieval and encrypted before storage + pub fn lpop(&self, key: &str, count: u64) -> Result, DBError> { + let write_txn = self.db.begin_write()?; + let mut result = Vec::new(); + + // First check if key exists and is a list, and get the data + let list_data = { + let types_table = write_txn.open_table(TYPES_TABLE)?; + let lists_table = write_txn.open_table(LISTS_TABLE)?; + + let result = match types_table.get(key)? { + Some(type_val) if type_val.value() == "list" => { + if let Some(data) = lists_table.get(key)? { + let decrypted = self.decrypt_if_needed(data.value())?; + let list: Vec = serde_json::from_slice(&decrypted)?; + Some(list) + } else { + None + } + } + _ => None, + }; + result + }; + + if let Some(mut list) = list_data { + let pop_count = std::cmp::min(count as usize, list.len()); + for _ in 0..pop_count { + if !list.is_empty() { + result.push(list.remove(0)); + } + } + + let mut lists_table = write_txn.open_table(LISTS_TABLE)?; + if list.is_empty() { + // Remove the key if list is empty + lists_table.remove(key)?; + let mut types_table = write_txn.open_table(TYPES_TABLE)?; + types_table.remove(key)?; + } else { + // Encrypt and store the updated list + let serialized = serde_json::to_vec(&list)?; + let encrypted = self.encrypt_if_needed(&serialized)?; + lists_table.insert(key, encrypted.as_slice())?; + } + } + + write_txn.commit()?; + Ok(result) + } + + // ✅ ENCRYPTION APPLIED: Elements are decrypted after retrieval and encrypted before storage + pub fn rpop(&self, key: &str, count: u64) -> Result, DBError> { + let write_txn = self.db.begin_write()?; + let mut result = Vec::new(); + + // First check if key exists and is a list, and get the data + let list_data = { + let types_table = write_txn.open_table(TYPES_TABLE)?; + let lists_table = write_txn.open_table(LISTS_TABLE)?; + + let result = match types_table.get(key)? { + Some(type_val) if type_val.value() == "list" => { + if let Some(data) = lists_table.get(key)? { + let decrypted = self.decrypt_if_needed(data.value())?; + let list: Vec = serde_json::from_slice(&decrypted)?; + Some(list) + } else { + None + } + } + _ => None, + }; + result + }; + + if let Some(mut list) = list_data { + let pop_count = std::cmp::min(count as usize, list.len()); + for _ in 0..pop_count { + if !list.is_empty() { + result.push(list.pop().unwrap()); + } + } + + let mut lists_table = write_txn.open_table(LISTS_TABLE)?; + if list.is_empty() { + // Remove the key if list is empty + lists_table.remove(key)?; + let mut types_table = write_txn.open_table(TYPES_TABLE)?; + types_table.remove(key)?; + } else { + // Encrypt and store the updated list + let serialized = serde_json::to_vec(&list)?; + let encrypted = self.encrypt_if_needed(&serialized)?; + lists_table.insert(key, encrypted.as_slice())?; + } + } + + write_txn.commit()?; + Ok(result) + } + + pub fn llen(&self, key: &str) -> Result { + let read_txn = self.db.begin_read()?; + let types_table = read_txn.open_table(TYPES_TABLE)?; + + match types_table.get(key)? { + Some(type_val) if type_val.value() == "list" => { + let lists_table = read_txn.open_table(LISTS_TABLE)?; + match lists_table.get(key)? { + Some(data) => { + let decrypted = self.decrypt_if_needed(data.value())?; + let list: Vec = serde_json::from_slice(&decrypted)?; + Ok(list.len() as i64) + } + None => Ok(0), + } + } + _ => Ok(0), + } + } + + // ✅ ENCRYPTION APPLIED: Element is decrypted after retrieval + pub fn lindex(&self, key: &str, index: i64) -> Result, DBError> { + let read_txn = self.db.begin_read()?; + let types_table = read_txn.open_table(TYPES_TABLE)?; + + match types_table.get(key)? { + Some(type_val) if type_val.value() == "list" => { + let lists_table = read_txn.open_table(LISTS_TABLE)?; + match lists_table.get(key)? { + Some(data) => { + let decrypted = self.decrypt_if_needed(data.value())?; + let list: Vec = serde_json::from_slice(&decrypted)?; + + let actual_index = if index < 0 { + list.len() as i64 + index + } else { + index + }; + + if actual_index >= 0 && (actual_index as usize) < list.len() { + Ok(Some(list[actual_index as usize].clone())) + } else { + Ok(None) + } + } + None => Ok(None), + } + } + _ => Ok(None), + } + } + + // ✅ ENCRYPTION APPLIED: Elements are decrypted after retrieval + pub fn lrange(&self, key: &str, start: i64, stop: i64) -> Result, DBError> { + let read_txn = self.db.begin_read()?; + let types_table = read_txn.open_table(TYPES_TABLE)?; + + match types_table.get(key)? { + Some(type_val) if type_val.value() == "list" => { + let lists_table = read_txn.open_table(LISTS_TABLE)?; + match lists_table.get(key)? { + Some(data) => { + let decrypted = self.decrypt_if_needed(data.value())?; + let list: Vec = serde_json::from_slice(&decrypted)?; + + if list.is_empty() { + return Ok(Vec::new()); + } + + let len = list.len() as i64; + let start_idx = if start < 0 { std::cmp::max(0, len + start) } else { std::cmp::min(start, len) }; + let stop_idx = if stop < 0 { std::cmp::max(-1, len + stop) } else { std::cmp::min(stop, len - 1) }; + + if start_idx > stop_idx || start_idx >= len { + return Ok(Vec::new()); + } + + let start_usize = start_idx as usize; + let stop_usize = (stop_idx + 1) as usize; + + Ok(list[start_usize..std::cmp::min(stop_usize, list.len())].to_vec()) + } + None => Ok(Vec::new()), + } + } + _ => Ok(Vec::new()), + } + } + + // ✅ ENCRYPTION APPLIED: Elements are decrypted after retrieval and encrypted before storage + pub fn ltrim(&self, key: &str, start: i64, stop: i64) -> Result<(), DBError> { + let write_txn = self.db.begin_write()?; + + // First check if key exists and is a list, and get the data + let list_data = { + let types_table = write_txn.open_table(TYPES_TABLE)?; + let lists_table = write_txn.open_table(LISTS_TABLE)?; + + let result = match types_table.get(key)? { + Some(type_val) if type_val.value() == "list" => { + if let Some(data) = lists_table.get(key)? { + let decrypted = self.decrypt_if_needed(data.value())?; + let list: Vec = serde_json::from_slice(&decrypted)?; + Some(list) + } else { + None + } + } + _ => None, + }; + result + }; + + if let Some(list) = list_data { + if list.is_empty() { + write_txn.commit()?; + return Ok(()); + } + + let len = list.len() as i64; + let start_idx = if start < 0 { std::cmp::max(0, len + start) } else { std::cmp::min(start, len) }; + let stop_idx = if stop < 0 { std::cmp::max(-1, len + stop) } else { std::cmp::min(stop, len - 1) }; + + let mut lists_table = write_txn.open_table(LISTS_TABLE)?; + if start_idx > stop_idx || start_idx >= len { + // Remove the entire list + lists_table.remove(key)?; + let mut types_table = write_txn.open_table(TYPES_TABLE)?; + types_table.remove(key)?; + } else { + let start_usize = start_idx as usize; + let stop_usize = (stop_idx + 1) as usize; + let trimmed = list[start_usize..std::cmp::min(stop_usize, list.len())].to_vec(); + + if trimmed.is_empty() { + lists_table.remove(key)?; + let mut types_table = write_txn.open_table(TYPES_TABLE)?; + types_table.remove(key)?; + } else { + // Encrypt and store the trimmed list + let serialized = serde_json::to_vec(&trimmed)?; + let encrypted = self.encrypt_if_needed(&serialized)?; + lists_table.insert(key, encrypted.as_slice())?; + } + } + } + + write_txn.commit()?; + Ok(()) + } + + // ✅ ENCRYPTION APPLIED: Elements are decrypted after retrieval and encrypted before storage + pub fn lrem(&self, key: &str, count: i64, element: &str) -> Result { + let write_txn = self.db.begin_write()?; + let mut removed = 0i64; + + // First check if key exists and is a list, and get the data + let list_data = { + let types_table = write_txn.open_table(TYPES_TABLE)?; + let lists_table = write_txn.open_table(LISTS_TABLE)?; + + let result = match types_table.get(key)? { + Some(type_val) if type_val.value() == "list" => { + if let Some(data) = lists_table.get(key)? { + let decrypted = self.decrypt_if_needed(data.value())?; + let list: Vec = serde_json::from_slice(&decrypted)?; + Some(list) + } else { + None + } + } + _ => None, + }; + result + }; + + if let Some(mut list) = list_data { + if count == 0 { + // Remove all occurrences + let original_len = list.len(); + list.retain(|x| x != element); + removed = (original_len - list.len()) as i64; + } else if count > 0 { + // Remove first count occurrences + let mut to_remove = count as usize; + list.retain(|x| { + if x == element && to_remove > 0 { + to_remove -= 1; + removed += 1; + false + } else { + true + } + }); + } else { + // Remove last |count| occurrences + let mut to_remove = (-count) as usize; + for i in (0..list.len()).rev() { + if list[i] == element && to_remove > 0 { + list.remove(i); + to_remove -= 1; + removed += 1; + } + } + } + + let mut lists_table = write_txn.open_table(LISTS_TABLE)?; + if list.is_empty() { + lists_table.remove(key)?; + let mut types_table = write_txn.open_table(TYPES_TABLE)?; + types_table.remove(key)?; + } else { + // Encrypt and store the updated list + let serialized = serde_json::to_vec(&list)?; + let encrypted = self.encrypt_if_needed(&serialized)?; + lists_table.insert(key, encrypted.as_slice())?; + } + } + + write_txn.commit()?; + Ok(removed) + } +} \ No newline at end of file