use std::{
    path::Path,
    sync::Arc,
    time::{SystemTime, UNIX_EPOCH},
};

use redb::{Database, TableDefinition};
use serde::{Deserialize, Serialize};

use crate::crypto::CryptoFactory;
use crate::error::DBError;

// Private submodules. The `StorageBackend` impl below calls methods such as `self.hset(..)`
// and `self.lpush(..)` that are not defined in this file. These submodules presumably hold
// those extra `impl Storage` blocks, one per data-type family — confirm.
mod storage_basic;
mod storage_hset;
mod storage_lists;
mod storage_extra;

// Inherent methods declared in `impl Storage` blocks inside the submodules above are
// callable on `Storage` without any `use`. This glob only re-exports whatever public items
// `storage_extra` itself declares (not visible from this file). The compiler may report it
// as unused because nothing in this file names those items directly.
pub use storage_extra::*;

// Table definitions for different Redis data types, plus two bookkeeping tables
// (`encrypted`, `expiration`). All of them are created up front in `Storage::new`.

/// Key -> type name of the value stored under that key (exact names set in the submodules).
const TYPES_TABLE: TableDefinition<&str, &str> = TableDefinition::new("types");
/// Key -> raw value bytes for string values.
const STRINGS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("strings");
/// `(key, field)` -> raw value bytes; presumably one row per hash field — confirm.
const HASHES_TABLE: TableDefinition<(&str, &str), &[u8]> = TableDefinition::new("hashes");
/// Key -> raw bytes for list values (presumably a serialized `ListValue` — confirm).
const LISTS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("lists");
/// Key -> raw bytes of per-stream metadata (format defined elsewhere).
const STREAMS_META_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("streams_meta");
/// `(stream key, entry id)` -> raw bytes; presumably a serialized `StreamEntry` — confirm.
const STREAMS_DATA_TABLE: TableDefinition<(&str, &str), &[u8]> = TableDefinition::new("streams_data");
/// Flag table: the row `"encrypted"` is set to `1` once encryption has been enabled.
const ENCRYPTED_TABLE: TableDefinition<&str, u8> = TableDefinition::new("encrypted");
/// Key -> expiration as `u64` (presumably an absolute ms timestamp comparable with
/// `now_in_millis` — confirm against the expiration code in the submodules).
const EXPIRATION_TABLE: TableDefinition<&str, u64> = TableDefinition::new("expiration");

/// Field/value pairs of a single stream entry, serializable via serde.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct StreamEntry {
    pub fields: Vec<(String, String)>,
}

/// Serializable container for a list's elements (presumably what is stored in
/// `LISTS_TABLE` — confirm in `storage_lists`).
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ListValue {
    pub elements: Vec,
}

/// Returns the current wall-clock time in milliseconds since the Unix epoch.
///
/// # Panics
///
/// Panics if the system clock reports a time earlier than `UNIX_EPOCH`.
#[inline]
pub fn now_in_millis() -> u128 {
    let start = SystemTime::now();
    let duration_since_epoch = start.duration_since(UNIX_EPOCH).unwrap();
    duration_since_epoch.as_millis()
}

/// redb-backed store for Redis-style data, with optional encryption of stored values
/// through a `CryptoFactory`.
pub struct Storage {
    db: Database,
    // `Some` when encryption is active; consulted by `encrypt_if_needed` / `decrypt_if_needed`.
    crypto: Option,
}

impl Storage {
    /// Opens (or creates) the redb database at `path` via `Database::create`, and makes
    /// sure every table exists.
    ///
    /// Encryption is enabled when `should_encrypt` is true **or** when the database was
    /// previously marked as encrypted. The first time it is enabled, the flag is
    /// persisted in `ENCRYPTED_TABLE`. After that, passing `should_encrypt = false` does
    /// not turn it off.
    ///
    /// # Errors
    ///
    /// Returns a `DBError` in two cases:
    /// - encryption is needed but `master_key` is `None`;
    /// - any redb open, transaction, table, get/insert or commit operation fails.
    ///
    /// NOTE(review): `master_key` is not compared with the key originally used to encrypt
    /// the database. A wrong key presumably only shows up later as a failure in
    /// `decrypt_if_needed` — confirm.
    pub fn new(path: impl AsRef, should_encrypt: bool, master_key: Option<&str>) -> Result {
        let db = Database::create(path)?;

        // Create tables if they don't exist, all in one write transaction. The inner scope
        // drops the table handles before `commit` is called.
        let write_txn = db.begin_write()?;
        {
            let _ = write_txn.open_table(TYPES_TABLE)?;
            let _ = write_txn.open_table(STRINGS_TABLE)?;
            let _ = write_txn.open_table(HASHES_TABLE)?;
            let _ = write_txn.open_table(LISTS_TABLE)?;
            let _ = write_txn.open_table(STREAMS_META_TABLE)?;
            let _ = write_txn.open_table(STREAMS_DATA_TABLE)?;
            let _ = write_txn.open_table(ENCRYPTED_TABLE)?;
            let _ = write_txn.open_table(EXPIRATION_TABLE)?;
        }
        write_txn.commit()?;

        // Check if database was previously encrypted. A missing flag row counts as
        // "not encrypted".
        let read_txn = db.begin_read()?;
        let encrypted_table = read_txn.open_table(ENCRYPTED_TABLE)?;
        let was_encrypted = encrypted_table
            .get("encrypted")?
            .map(|v| v.value() == 1)
            .unwrap_or(false);
        // End the read transaction before the optional write transaction below.
        drop(read_txn);

        let crypto = if should_encrypt || was_encrypted {
            if let Some(key) = master_key {
                Some(CryptoFactory::new(key.as_bytes()))
            } else {
                // NOTE(review): this message is also returned when the caller did not
                // request encryption but the database was already encrypted.
                return Err(DBError("Encryption requested but no master key provided".to_string()));
            }
        } else {
            None
        };

        // If we're enabling encryption for the first time, mark it. This only runs once
        // the key check above has succeeded, so the flag is never written without a key.
        if should_encrypt && !was_encrypted {
            let write_txn = db.begin_write()?;
            {
                let mut encrypted_table = write_txn.open_table(ENCRYPTED_TABLE)?;
                encrypted_table.insert("encrypted", &1u8)?;
            }
            write_txn.commit()?;
        }

        Ok(Storage {
            db,
            crypto,
        })
    }

    /// Returns `true` when this handle encrypts and decrypts values (a `CryptoFactory` is set).
    pub fn is_encrypted(&self) -> bool {
        self.crypto.is_some()
    }

    // Helper methods for encryption

    /// Encrypts `data` when encryption is active; otherwise returns an unmodified copy.
    ///
    /// Both branches return `Ok`. The `Result` presumably mirrors `decrypt_if_needed` so
    /// callers can treat the two uniformly.
    fn encrypt_if_needed(&self, data: &[u8]) -> Result, DBError> {
        if let Some(crypto) = &self.crypto {
            Ok(crypto.encrypt(data))
        } else {
            Ok(data.to_vec())
        }
    }

    /// Decrypts `data` when encryption is active; otherwise returns an unmodified copy.
    ///
    /// # Errors
    ///
    /// Propagates any error from `crypto.decrypt`, converted into `DBError` by `?`.
    fn decrypt_if_needed(&self, data: &[u8]) -> Result, DBError> {
        if let Some(crypto) = &self.crypto {
            Ok(crypto.decrypt(data)?)
        } else {
            Ok(data.to_vec())
        }
    }
}

use crate::storage_trait::StorageBackend;

// `StorageBackend` is implemented by forwarding every call to the inherent `Storage` method
// of the same name. Those methods (apart from `is_encrypted`) are not in this file;
// presumably they live in the submodules.
//
// In method-call resolution, inherent methods take precedence over trait methods. So
// `self.get(key)` here calls the inherent `Storage::get`, not this trait method.
//
// NOTE(review): if an inherent method is renamed or removed, the call silently resolves
// to the trait method and recurses forever. Keep rustc's `unconditional_recursion` lint
// enabled to catch this.
impl StorageBackend for Storage {
    // --- Keys and string values ---
    fn get(&self, key: &str) -> Result, DBError> {
        self.get(key)
    }
    fn set(&self, key: String, value: String) -> Result<(), DBError> {
        self.set(key, value)
    }
    fn setx(&self, key: String, value: String, expire_ms: u128) -> Result<(), DBError> {
        self.setx(key, value, expire_ms)
    }
    fn del(&self, key: String) -> Result<(), DBError> {
        self.del(key)
    }
    fn exists(&self, key: &str) -> Result {
        self.exists(key)
    }
    fn keys(&self, pattern: &str) -> Result, DBError> {
        self.keys(pattern)
    }
    fn dbsize(&self) -> Result {
        self.dbsize()
    }
    fn flushdb(&self) -> Result<(), DBError> {
        self.flushdb()
    }
    fn get_key_type(&self, key: &str) -> Result, DBError> {
        self.get_key_type(key)
    }
    fn scan(
        &self,
        cursor: u64,
        pattern: Option<&str>,
        count: Option,
    ) -> Result<(u64, Vec<(String, String)>), DBError> {
        self.scan(cursor, pattern, count)
    }

    // --- Hashes ---
    fn hscan(
        &self,
        key: &str,
        cursor: u64,
        pattern: Option<&str>,
        count: Option,
    ) -> Result<(u64, Vec<(String, String)>), DBError> {
        self.hscan(key, cursor, pattern, count)
    }
    fn hset(&self, key: &str, pairs: Vec<(String, String)>) -> Result {
        self.hset(key, pairs)
    }
    fn hget(&self, key: &str, field: &str) -> Result, DBError> {
        self.hget(key, field)
    }
    fn hgetall(&self, key: &str) -> Result, DBError> {
        self.hgetall(key)
    }
    fn hdel(&self, key: &str, fields: Vec) -> Result {
        self.hdel(key, fields)
    }
    fn hexists(&self, key: &str, field: &str) -> Result {
        self.hexists(key, field)
    }
    fn hkeys(&self, key: &str) -> Result, DBError> {
        self.hkeys(key)
    }
    fn hvals(&self, key: &str) -> Result, DBError> {
        self.hvals(key)
    }
    fn hlen(&self, key: &str) -> Result {
        self.hlen(key)
    }
    fn hmget(&self, key: &str, fields: Vec) -> Result>, DBError> {
        self.hmget(key, fields)
    }
    fn hsetnx(&self, key: &str, field: &str, value: &str) -> Result {
        self.hsetnx(key, field, value)
    }

    // --- Lists ---
    fn lpush(&self, key: &str, elements: Vec) -> Result {
        self.lpush(key, elements)
    }
    fn rpush(&self, key: &str, elements: Vec) -> Result {
        self.rpush(key, elements)
    }
    fn lpop(&self, key: &str, count: u64) -> Result, DBError> {
        self.lpop(key, count)
    }
    fn rpop(&self, key: &str, count: u64) -> Result, DBError> {
        self.rpop(key, count)
    }
    fn llen(&self, key: &str) -> Result {
        self.llen(key)
    }
    fn lindex(&self, key: &str, index: i64) -> Result, DBError> {
        self.lindex(key, index)
    }
    fn lrange(&self, key: &str, start: i64, stop: i64) -> Result, DBError> {
        self.lrange(key, start, stop)
    }
    fn ltrim(&self, key: &str, start: i64, stop: i64) -> Result<(), DBError> {
        self.ltrim(key, start, stop)
    }
    fn lrem(&self, key: &str, count: i64, element: &str) -> Result {
        self.lrem(key, count, element)
    }

    // --- Expiration ---
    fn ttl(&self, key: &str) -> Result {
        self.ttl(key)
    }
    fn expire_seconds(&self, key: &str, secs: u64) -> Result {
        self.expire_seconds(key, secs)
    }
    fn pexpire_millis(&self, key: &str, ms: u128) -> Result {
        self.pexpire_millis(key, ms)
    }
    fn persist(&self, key: &str) -> Result {
        self.persist(key)
    }
    fn expire_at_seconds(&self, key: &str, ts_secs: i64) -> Result {
        self.expire_at_seconds(key, ts_secs)
    }
    fn pexpire_at_millis(&self, key: &str, ts_ms: i64) -> Result {
        self.pexpire_at_millis(key, ts_ms)
    }

    // --- Miscellaneous ---
    /// Forwards to the inherent `Storage::is_encrypted`.
    fn is_encrypted(&self) -> bool {
        self.is_encrypted()
    }
    fn info(&self) -> Result, DBError> {
        self.info()
    }
    /// Not supported by this backend.
    ///
    /// # Panics
    ///
    /// Always panics via `unimplemented!`.
    // NOTE(review): `Storage` owns its `Database` directly, so there is nothing cheap to
    // share. Supporting this would presumably require holding the database behind an
    // `Arc` — confirm.
    fn clone_arc(&self) -> Arc {
        unimplemented!("Storage cloning not yet implemented for redb backend")
    }
}