// src/storage_sled/mod.rs use std::path::Path; use std::sync::Arc; use std::collections::HashMap; use std::time::{SystemTime, UNIX_EPOCH}; use serde::{Deserialize, Serialize}; use crate::error::DBError; use crate::storage_trait::StorageBackend; use crate::crypto::CryptoFactory; #[derive(Serialize, Deserialize, Debug, Clone)] enum ValueType { String(String), Hash(HashMap), List(Vec), } #[derive(Serialize, Deserialize, Debug, Clone)] struct StorageValue { value: ValueType, expires_at: Option, // milliseconds since epoch } pub struct SledStorage { db: sled::Db, types: sled::Tree, crypto: Option, } impl SledStorage { pub fn new(path: impl AsRef, should_encrypt: bool, master_key: Option<&str>) -> Result { let db = sled::open(path).map_err(|e| DBError(format!("Failed to open sled: {}", e)))?; let types = db.open_tree("types").map_err(|e| DBError(format!("Failed to open types tree: {}", e)))?; // Check if database was previously encrypted let encrypted_tree = db.open_tree("encrypted").map_err(|e| DBError(e.to_string()))?; let was_encrypted = encrypted_tree.get("encrypted") .map_err(|e| DBError(e.to_string()))? .map(|v| v[0] == 1) .unwrap_or(false); let crypto = if should_encrypt || was_encrypted { if let Some(key) = master_key { Some(CryptoFactory::new(key.as_bytes())) } else { return Err(DBError("Encryption requested but no master key provided".to_string())); } } else { None }; // Mark database as encrypted if enabling encryption if should_encrypt && !was_encrypted { encrypted_tree.insert("encrypted", &[1u8]) .map_err(|e| DBError(e.to_string()))?; encrypted_tree.flush().map_err(|e| DBError(e.to_string()))?; } Ok(SledStorage { db, types, crypto }) } fn now_millis() -> u128 { SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap() .as_millis() } fn encrypt_if_needed(&self, data: &[u8]) -> Result, DBError> { if let Some(crypto) = &self.crypto { Ok(crypto.encrypt(data)) } else { Ok(data.to_vec()) } } fn decrypt_if_needed(&self, data: &[u8]) -> Result, DBError> { if let Some(crypto) = &self.crypto { Ok(crypto.decrypt(data)?) } else { Ok(data.to_vec()) } } fn get_storage_value(&self, key: &str) -> Result, DBError> { match self.db.get(key).map_err(|e| DBError(e.to_string()))? { Some(encrypted_data) => { let decrypted = self.decrypt_if_needed(&encrypted_data)?; let storage_val: StorageValue = bincode::deserialize(&decrypted) .map_err(|e| DBError(format!("Deserialization error: {}", e)))?; // Check expiration if let Some(expires_at) = storage_val.expires_at { if Self::now_millis() > expires_at { // Expired, remove it self.db.remove(key).map_err(|e| DBError(e.to_string()))?; self.types.remove(key).map_err(|e| DBError(e.to_string()))?; return Ok(None); } } Ok(Some(storage_val)) } None => Ok(None) } } fn set_storage_value(&self, key: &str, storage_val: StorageValue) -> Result<(), DBError> { let data = bincode::serialize(&storage_val) .map_err(|e| DBError(format!("Serialization error: {}", e)))?; let encrypted = self.encrypt_if_needed(&data)?; self.db.insert(key, encrypted).map_err(|e| DBError(e.to_string()))?; // Store type info (unencrypted for efficiency) let type_str = match &storage_val.value { ValueType::String(_) => "string", ValueType::Hash(_) => "hash", ValueType::List(_) => "list", }; self.types.insert(key, type_str.as_bytes()).map_err(|e| DBError(e.to_string()))?; Ok(()) } fn glob_match(pattern: &str, text: &str) -> bool { if pattern == "*" { return true; } let pattern_chars: Vec = pattern.chars().collect(); let text_chars: Vec = text.chars().collect(); fn match_recursive(pattern: &[char], text: &[char], pi: usize, ti: usize) -> bool { if pi >= pattern.len() { return ti >= text.len(); } if ti >= text.len() { return pattern[pi..].iter().all(|&c| c == '*'); } match pattern[pi] { '*' => { for i in ti..=text.len() { if match_recursive(pattern, text, pi + 1, i) { return true; } } false } '?' => match_recursive(pattern, text, pi + 1, ti + 1), c => { if text[ti] == c { match_recursive(pattern, text, pi + 1, ti + 1) } else { false } } } } match_recursive(&pattern_chars, &text_chars, 0, 0) } } impl StorageBackend for SledStorage { fn get(&self, key: &str) -> Result, DBError> { match self.get_storage_value(key)? { Some(storage_val) => match storage_val.value { ValueType::String(s) => Ok(Some(s)), _ => Ok(None) } None => Ok(None) } } fn set(&self, key: String, value: String) -> Result<(), DBError> { let storage_val = StorageValue { value: ValueType::String(value), expires_at: None, }; self.set_storage_value(&key, storage_val)?; self.db.flush().map_err(|e| DBError(e.to_string()))?; Ok(()) } fn setx(&self, key: String, value: String, expire_ms: u128) -> Result<(), DBError> { let storage_val = StorageValue { value: ValueType::String(value), expires_at: Some(Self::now_millis() + expire_ms), }; self.set_storage_value(&key, storage_val)?; self.db.flush().map_err(|e| DBError(e.to_string()))?; Ok(()) } fn del(&self, key: String) -> Result<(), DBError> { self.db.remove(&key).map_err(|e| DBError(e.to_string()))?; self.types.remove(&key).map_err(|e| DBError(e.to_string()))?; self.db.flush().map_err(|e| DBError(e.to_string()))?; Ok(()) } fn exists(&self, key: &str) -> Result { // Check with expiration Ok(self.get_storage_value(key)?.is_some()) } fn keys(&self, pattern: &str) -> Result, DBError> { let mut keys = Vec::new(); for item in self.types.iter() { let (key_bytes, _) = item.map_err(|e| DBError(e.to_string()))?; let key = String::from_utf8_lossy(&key_bytes).to_string(); // Check if key is expired if self.get_storage_value(&key)?.is_some() { if Self::glob_match(pattern, &key) { keys.push(key); } } } Ok(keys) } fn scan(&self, cursor: u64, pattern: Option<&str>, count: Option) -> Result<(u64, Vec<(String, String)>), DBError> { let mut result = Vec::new(); let mut current_cursor = 0u64; let limit = count.unwrap_or(10) as usize; for item in self.types.iter() { if current_cursor >= cursor { let (key_bytes, type_bytes) = item.map_err(|e| DBError(e.to_string()))?; let key = String::from_utf8_lossy(&key_bytes).to_string(); // Check pattern match let matches = if let Some(pat) = pattern { Self::glob_match(pat, &key) } else { true }; if matches { // Check if key is expired and get value if let Some(storage_val) = self.get_storage_value(&key)? { let value = match storage_val.value { ValueType::String(s) => s, _ => String::from_utf8_lossy(&type_bytes).to_string(), }; result.push((key, value)); if result.len() >= limit { current_cursor += 1; break; } } } } current_cursor += 1; } let next_cursor = if result.len() < limit { 0 } else { current_cursor }; Ok((next_cursor, result)) } fn dbsize(&self) -> Result { let mut count = 0i64; for item in self.types.iter() { let (key_bytes, _) = item.map_err(|e| DBError(e.to_string()))?; let key = String::from_utf8_lossy(&key_bytes).to_string(); if self.get_storage_value(&key)?.is_some() { count += 1; } } Ok(count) } fn flushdb(&self) -> Result<(), DBError> { self.db.clear().map_err(|e| DBError(e.to_string()))?; self.types.clear().map_err(|e| DBError(e.to_string()))?; self.db.flush().map_err(|e| DBError(e.to_string()))?; Ok(()) } fn get_key_type(&self, key: &str) -> Result, DBError> { // First check if key exists (handles expiration) if self.get_storage_value(key)?.is_some() { match self.types.get(key).map_err(|e| DBError(e.to_string()))? { Some(data) => Ok(Some(String::from_utf8_lossy(&data).to_string())), None => Ok(None) } } else { Ok(None) } } // Hash operations fn hset(&self, key: &str, pairs: Vec<(String, String)>) -> Result { let mut storage_val = self.get_storage_value(key)?.unwrap_or(StorageValue { value: ValueType::Hash(HashMap::new()), expires_at: None, }); let hash = match &mut storage_val.value { ValueType::Hash(h) => h, _ => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), }; let mut new_fields = 0i64; for (field, value) in pairs { if !hash.contains_key(&field) { new_fields += 1; } hash.insert(field, value); } self.set_storage_value(key, storage_val)?; self.db.flush().map_err(|e| DBError(e.to_string()))?; Ok(new_fields) } fn hget(&self, key: &str, field: &str) -> Result, DBError> { match self.get_storage_value(key)? { Some(storage_val) => match storage_val.value { ValueType::Hash(h) => Ok(h.get(field).cloned()), _ => Ok(None) } None => Ok(None) } } fn hgetall(&self, key: &str) -> Result, DBError> { match self.get_storage_value(key)? { Some(storage_val) => match storage_val.value { ValueType::Hash(h) => Ok(h.into_iter().collect()), _ => Ok(Vec::new()) } None => Ok(Vec::new()) } } fn hscan(&self, key: &str, cursor: u64, pattern: Option<&str>, count: Option) -> Result<(u64, Vec<(String, String)>), DBError> { match self.get_storage_value(key)? { Some(storage_val) => match storage_val.value { ValueType::Hash(h) => { let mut result = Vec::new(); let mut current_cursor = 0u64; let limit = count.unwrap_or(10) as usize; for (field, value) in h.iter() { if current_cursor >= cursor { let matches = if let Some(pat) = pattern { Self::glob_match(pat, field) } else { true }; if matches { result.push((field.clone(), value.clone())); if result.len() >= limit { current_cursor += 1; break; } } } current_cursor += 1; } let next_cursor = if result.len() < limit { 0 } else { current_cursor }; Ok((next_cursor, result)) } _ => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())) } None => Ok((0, Vec::new())) } } fn hdel(&self, key: &str, fields: Vec) -> Result { let mut storage_val = match self.get_storage_value(key)? { Some(sv) => sv, None => return Ok(0) }; let hash = match &mut storage_val.value { ValueType::Hash(h) => h, _ => return Ok(0) }; let mut deleted = 0i64; for field in fields { if hash.remove(&field).is_some() { deleted += 1; } } if hash.is_empty() { self.del(key.to_string())?; } else { self.set_storage_value(key, storage_val)?; self.db.flush().map_err(|e| DBError(e.to_string()))?; } Ok(deleted) } fn hexists(&self, key: &str, field: &str) -> Result { match self.get_storage_value(key)? { Some(storage_val) => match storage_val.value { ValueType::Hash(h) => Ok(h.contains_key(field)), _ => Ok(false) } None => Ok(false) } } fn hkeys(&self, key: &str) -> Result, DBError> { match self.get_storage_value(key)? { Some(storage_val) => match storage_val.value { ValueType::Hash(h) => Ok(h.keys().cloned().collect()), _ => Ok(Vec::new()) } None => Ok(Vec::new()) } } fn hvals(&self, key: &str) -> Result, DBError> { match self.get_storage_value(key)? { Some(storage_val) => match storage_val.value { ValueType::Hash(h) => Ok(h.values().cloned().collect()), _ => Ok(Vec::new()) } None => Ok(Vec::new()) } } fn hlen(&self, key: &str) -> Result { match self.get_storage_value(key)? { Some(storage_val) => match storage_val.value { ValueType::Hash(h) => Ok(h.len() as i64), _ => Ok(0) } None => Ok(0) } } fn hmget(&self, key: &str, fields: Vec) -> Result>, DBError> { match self.get_storage_value(key)? { Some(storage_val) => match storage_val.value { ValueType::Hash(h) => { Ok(fields.into_iter().map(|f| h.get(&f).cloned()).collect()) } _ => Ok(fields.into_iter().map(|_| None).collect()) } None => Ok(fields.into_iter().map(|_| None).collect()) } } fn hsetnx(&self, key: &str, field: &str, value: &str) -> Result { let mut storage_val = self.get_storage_value(key)?.unwrap_or(StorageValue { value: ValueType::Hash(HashMap::new()), expires_at: None, }); let hash = match &mut storage_val.value { ValueType::Hash(h) => h, _ => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), }; if hash.contains_key(field) { Ok(false) } else { hash.insert(field.to_string(), value.to_string()); self.set_storage_value(key, storage_val)?; self.db.flush().map_err(|e| DBError(e.to_string()))?; Ok(true) } } // List operations fn lpush(&self, key: &str, elements: Vec) -> Result { let mut storage_val = self.get_storage_value(key)?.unwrap_or(StorageValue { value: ValueType::List(Vec::new()), expires_at: None, }); let list = match &mut storage_val.value { ValueType::List(l) => l, _ => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), }; for element in elements.into_iter().rev() { list.insert(0, element); } let len = list.len() as i64; self.set_storage_value(key, storage_val)?; self.db.flush().map_err(|e| DBError(e.to_string()))?; Ok(len) } fn rpush(&self, key: &str, elements: Vec) -> Result { let mut storage_val = self.get_storage_value(key)?.unwrap_or(StorageValue { value: ValueType::List(Vec::new()), expires_at: None, }); let list = match &mut storage_val.value { ValueType::List(l) => l, _ => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), }; list.extend(elements); let len = list.len() as i64; self.set_storage_value(key, storage_val)?; self.db.flush().map_err(|e| DBError(e.to_string()))?; Ok(len) } fn lpop(&self, key: &str, count: u64) -> Result, DBError> { let mut storage_val = match self.get_storage_value(key)? { Some(sv) => sv, None => return Ok(Vec::new()) }; let list = match &mut storage_val.value { ValueType::List(l) => l, _ => return Ok(Vec::new()) }; let mut result = Vec::new(); for _ in 0..count.min(list.len() as u64) { if let Some(elem) = list.first() { result.push(elem.clone()); list.remove(0); } } if list.is_empty() { self.del(key.to_string())?; } else { self.set_storage_value(key, storage_val)?; self.db.flush().map_err(|e| DBError(e.to_string()))?; } Ok(result) } fn rpop(&self, key: &str, count: u64) -> Result, DBError> { let mut storage_val = match self.get_storage_value(key)? { Some(sv) => sv, None => return Ok(Vec::new()) }; let list = match &mut storage_val.value { ValueType::List(l) => l, _ => return Ok(Vec::new()) }; let mut result = Vec::new(); for _ in 0..count.min(list.len() as u64) { if let Some(elem) = list.pop() { result.push(elem); } } if list.is_empty() { self.del(key.to_string())?; } else { self.set_storage_value(key, storage_val)?; self.db.flush().map_err(|e| DBError(e.to_string()))?; } Ok(result) } fn llen(&self, key: &str) -> Result { match self.get_storage_value(key)? { Some(storage_val) => match storage_val.value { ValueType::List(l) => Ok(l.len() as i64), _ => Ok(0) } None => Ok(0) } } fn lindex(&self, key: &str, index: i64) -> Result, DBError> { match self.get_storage_value(key)? { Some(storage_val) => match storage_val.value { ValueType::List(list) => { let actual_index = if index < 0 { list.len() as i64 + index } else { index }; if actual_index >= 0 && (actual_index as usize) < list.len() { Ok(Some(list[actual_index as usize].clone())) } else { Ok(None) } } _ => Ok(None) } None => Ok(None) } } fn lrange(&self, key: &str, start: i64, stop: i64) -> Result, DBError> { match self.get_storage_value(key)? { Some(storage_val) => match storage_val.value { ValueType::List(list) => { if list.is_empty() { return Ok(Vec::new()); } let len = list.len() as i64; let start_idx = if start < 0 { std::cmp::max(0, len + start) } else { std::cmp::min(start, len) }; let stop_idx = if stop < 0 { std::cmp::max(-1, len + stop) } else { std::cmp::min(stop, len - 1) }; if start_idx > stop_idx || start_idx >= len { return Ok(Vec::new()); } let start_usize = start_idx as usize; let stop_usize = (stop_idx + 1) as usize; Ok(list[start_usize..std::cmp::min(stop_usize, list.len())].to_vec()) } _ => Ok(Vec::new()) } None => Ok(Vec::new()) } } fn ltrim(&self, key: &str, start: i64, stop: i64) -> Result<(), DBError> { let mut storage_val = match self.get_storage_value(key)? { Some(sv) => sv, None => return Ok(()) }; let list = match &mut storage_val.value { ValueType::List(l) => l, _ => return Ok(()) }; if list.is_empty() { return Ok(()); } let len = list.len() as i64; let start_idx = if start < 0 { std::cmp::max(0, len + start) } else { std::cmp::min(start, len) }; let stop_idx = if stop < 0 { std::cmp::max(-1, len + stop) } else { std::cmp::min(stop, len - 1) }; if start_idx > stop_idx || start_idx >= len { self.del(key.to_string())?; } else { let start_usize = start_idx as usize; let stop_usize = (stop_idx + 1) as usize; *list = list[start_usize..std::cmp::min(stop_usize, list.len())].to_vec(); if list.is_empty() { self.del(key.to_string())?; } else { self.set_storage_value(key, storage_val)?; self.db.flush().map_err(|e| DBError(e.to_string()))?; } } Ok(()) } fn lrem(&self, key: &str, count: i64, element: &str) -> Result { let mut storage_val = match self.get_storage_value(key)? { Some(sv) => sv, None => return Ok(0) }; let list = match &mut storage_val.value { ValueType::List(l) => l, _ => return Ok(0) }; let mut removed = 0i64; if count == 0 { // Remove all occurrences let original_len = list.len(); list.retain(|x| x != element); removed = (original_len - list.len()) as i64; } else if count > 0 { // Remove first count occurrences let mut to_remove = count as usize; list.retain(|x| { if x == element && to_remove > 0 { to_remove -= 1; removed += 1; false } else { true } }); } else { // Remove last |count| occurrences let mut to_remove = (-count) as usize; for i in (0..list.len()).rev() { if list[i] == element && to_remove > 0 { list.remove(i); to_remove -= 1; removed += 1; } } } if list.is_empty() { self.del(key.to_string())?; } else { self.set_storage_value(key, storage_val)?; self.db.flush().map_err(|e| DBError(e.to_string()))?; } Ok(removed) } // Expiration fn ttl(&self, key: &str) -> Result { match self.get_storage_value(key)? { Some(storage_val) => { if let Some(expires_at) = storage_val.expires_at { let now = Self::now_millis(); if now >= expires_at { Ok(-2) // Key has expired } else { Ok(((expires_at - now) / 1000) as i64) // TTL in seconds } } else { Ok(-1) // Key exists but has no expiration } } None => Ok(-2) // Key does not exist } } fn expire_seconds(&self, key: &str, secs: u64) -> Result { let mut storage_val = match self.get_storage_value(key)? { Some(sv) => sv, None => return Ok(false) }; storage_val.expires_at = Some(Self::now_millis() + (secs as u128) * 1000); self.set_storage_value(key, storage_val)?; self.db.flush().map_err(|e| DBError(e.to_string()))?; Ok(true) } fn pexpire_millis(&self, key: &str, ms: u128) -> Result { let mut storage_val = match self.get_storage_value(key)? { Some(sv) => sv, None => return Ok(false) }; storage_val.expires_at = Some(Self::now_millis() + ms); self.set_storage_value(key, storage_val)?; self.db.flush().map_err(|e| DBError(e.to_string()))?; Ok(true) } fn persist(&self, key: &str) -> Result { let mut storage_val = match self.get_storage_value(key)? { Some(sv) => sv, None => return Ok(false) }; if storage_val.expires_at.is_some() { storage_val.expires_at = None; self.set_storage_value(key, storage_val)?; self.db.flush().map_err(|e| DBError(e.to_string()))?; Ok(true) } else { Ok(false) } } fn expire_at_seconds(&self, key: &str, ts_secs: i64) -> Result { let mut storage_val = match self.get_storage_value(key)? { Some(sv) => sv, None => return Ok(false) }; let expires_at_ms: u128 = if ts_secs <= 0 { 0 } else { (ts_secs as u128) * 1000 }; storage_val.expires_at = Some(expires_at_ms); self.set_storage_value(key, storage_val)?; self.db.flush().map_err(|e| DBError(e.to_string()))?; Ok(true) } fn pexpire_at_millis(&self, key: &str, ts_ms: i64) -> Result { let mut storage_val = match self.get_storage_value(key)? { Some(sv) => sv, None => return Ok(false) }; let expires_at_ms: u128 = if ts_ms <= 0 { 0 } else { ts_ms as u128 }; storage_val.expires_at = Some(expires_at_ms); self.set_storage_value(key, storage_val)?; self.db.flush().map_err(|e| DBError(e.to_string()))?; Ok(true) } fn is_encrypted(&self) -> bool { self.crypto.is_some() } fn info(&self) -> Result, DBError> { let dbsize = self.dbsize()?; Ok(vec![ ("db_size".to_string(), dbsize.to_string()), ("is_encrypted".to_string(), self.is_encrypted().to_string()), ]) } fn clone_arc(&self) -> Arc { // Note: This is a simplified clone - in production you might want to // handle this differently as sled::Db is already Arc internally Arc::new(SledStorage { db: self.db.clone(), types: self.types.clone(), crypto: self.crypto.clone(), }) } }