287 lines
8.6 KiB
Rust
287 lines
8.6 KiB
Rust
use std::{
|
|
path::Path,
|
|
sync::Arc,
|
|
time::{SystemTime, UNIX_EPOCH},
|
|
};
|
|
|
|
use redb::{Database, TableDefinition};
|
|
use serde::{Deserialize, Serialize};
|
|
|
|
use crate::crypto::CryptoFactory;
|
|
use crate::error::DBError;
|
|
|
|
// Submodules holding the remaining `impl Storage` methods (only `storage_extra` is re-exported below)
|
|
mod storage_basic;
|
|
mod storage_hset;
|
|
mod storage_lists;
|
|
mod storage_extra;
|
|
|
|
// Re-export the public items of `storage_extra` from this module.
// Note: inherent `impl Storage` blocks defined in the submodules are available
// without any `use`; this glob re-export only matters for free items (functions,
// types, constants) that `storage_extra` declares `pub`.
// NOTE(review): confirm this re-export is still needed by callers.
|
|
pub use storage_extra::*;
|
|
|
|
// redb table definitions: one table per Redis data type, plus bookkeeping tables for key types, the encryption marker and expirations
|
|
// The string passed to `TableDefinition::new` is the table's on-disk name; renaming
// one would make data stored under the old name invisible to existing databases.

/// Key -> name of the type stored under that key
/// (NOTE(review): the exact type-name strings are written elsewhere — confirm).
const TYPES_TABLE: TableDefinition<&str, &str> = TableDefinition::new("types");
/// Key -> bytes of a string value (presumably passed through `encrypt_if_needed`
/// when encryption is enabled — confirm in the submodules).
const STRINGS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("strings");
/// Composite key -> bytes of one hash field's value; presumably `(hash key, field)`
/// — confirm against the hash implementation.
const HASHES_TABLE: TableDefinition<(&str, &str), &[u8]> = TableDefinition::new("hashes");
/// Key -> bytes of a whole list (presumably a serialized `ListValue` — confirm).
const LISTS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("lists");
/// Key -> bytes of per-stream metadata (format not visible here — TODO confirm).
const STREAMS_META_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("streams_meta");
/// Composite key -> bytes of one stream entry; presumably `(stream key, entry id)`
/// mapping to a serialized `StreamEntry` — confirm.
const STREAMS_DATA_TABLE: TableDefinition<(&str, &str), &[u8]> = TableDefinition::new("streams_data");
/// Marker table: `Storage::new` stores `1u8` under the key `"encrypted"` the first
/// time encryption is enabled, and reads it back on every open.
const ENCRYPTED_TABLE: TableDefinition<&str, u8> = TableDefinition::new("encrypted");
/// Key -> expiration as `u64` (presumably a UNIX timestamp in milliseconds, matching
/// `now_in_millis` — confirm).
const EXPIRATION_TABLE: TableDefinition<&str, u64> = TableDefinition::new("expiration");
|
|
|
|
#[derive(Serialize, Deserialize, Debug, Clone)]
|
|
pub struct StreamEntry {
|
|
pub fields: Vec<(String, String)>,
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize, Debug, Clone)]
|
|
pub struct ListValue {
|
|
pub elements: Vec<String>,
|
|
}
|
|
|
|
/// Returns the current wall-clock time as milliseconds since the UNIX epoch.
///
/// This reads `SystemTime`, so the value is not monotonic: it can jump if
/// the system clock is adjusted.
///
/// # Panics
///
/// Panics if the system clock reports a time before the UNIX epoch.
#[inline]
pub fn now_in_millis() -> u128 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("system clock is set before the UNIX epoch")
        .as_millis()
}
|
|
|
|
/// redb-backed key-value store.
///
/// Construct it with `Storage::new`. Per-command operations live in further
/// `impl Storage` blocks (presumably in the `storage_*` submodules).
pub struct Storage {
    /// Underlying redb database; `Storage::new` creates every table up front.
    db: Database,
    /// `Some` when encryption is active for this handle, either because the
    /// caller requested it or because the database was already marked
    /// encrypted. `encrypt_if_needed`/`decrypt_if_needed` consult it; `None`
    /// means values are stored unchanged.
    crypto: Option<CryptoFactory>,
}
|
|
|
|
impl Storage {
|
|
pub fn new(path: impl AsRef<Path>, should_encrypt: bool, master_key: Option<&str>) -> Result<Self, DBError> {
|
|
let db = Database::create(path)?;
|
|
|
|
// Create tables if they don't exist
|
|
let write_txn = db.begin_write()?;
|
|
{
|
|
let _ = write_txn.open_table(TYPES_TABLE)?;
|
|
let _ = write_txn.open_table(STRINGS_TABLE)?;
|
|
let _ = write_txn.open_table(HASHES_TABLE)?;
|
|
let _ = write_txn.open_table(LISTS_TABLE)?;
|
|
let _ = write_txn.open_table(STREAMS_META_TABLE)?;
|
|
let _ = write_txn.open_table(STREAMS_DATA_TABLE)?;
|
|
let _ = write_txn.open_table(ENCRYPTED_TABLE)?;
|
|
let _ = write_txn.open_table(EXPIRATION_TABLE)?;
|
|
}
|
|
write_txn.commit()?;
|
|
|
|
// Check if database was previously encrypted
|
|
let read_txn = db.begin_read()?;
|
|
let encrypted_table = read_txn.open_table(ENCRYPTED_TABLE)?;
|
|
let was_encrypted = encrypted_table.get("encrypted")?.map(|v| v.value() == 1).unwrap_or(false);
|
|
drop(read_txn);
|
|
|
|
let crypto = if should_encrypt || was_encrypted {
|
|
if let Some(key) = master_key {
|
|
Some(CryptoFactory::new(key.as_bytes()))
|
|
} else {
|
|
return Err(DBError("Encryption requested but no master key provided".to_string()));
|
|
}
|
|
} else {
|
|
None
|
|
};
|
|
|
|
// If we're enabling encryption for the first time, mark it
|
|
if should_encrypt && !was_encrypted {
|
|
let write_txn = db.begin_write()?;
|
|
{
|
|
let mut encrypted_table = write_txn.open_table(ENCRYPTED_TABLE)?;
|
|
encrypted_table.insert("encrypted", &1u8)?;
|
|
}
|
|
write_txn.commit()?;
|
|
}
|
|
|
|
Ok(Storage {
|
|
db,
|
|
crypto,
|
|
})
|
|
}
|
|
|
|
pub fn is_encrypted(&self) -> bool {
|
|
self.crypto.is_some()
|
|
}
|
|
|
|
// Helper methods for encryption
|
|
fn encrypt_if_needed(&self, data: &[u8]) -> Result<Vec<u8>, DBError> {
|
|
if let Some(crypto) = &self.crypto {
|
|
Ok(crypto.encrypt(data))
|
|
} else {
|
|
Ok(data.to_vec())
|
|
}
|
|
}
|
|
|
|
fn decrypt_if_needed(&self, data: &[u8]) -> Result<Vec<u8>, DBError> {
|
|
if let Some(crypto) = &self.crypto {
|
|
Ok(crypto.decrypt(data)?)
|
|
} else {
|
|
Ok(data.to_vec())
|
|
}
|
|
}
|
|
}
|
|
|
|
use crate::storage_trait::StorageBackend;
|
|
|
|
impl StorageBackend for Storage {
|
|
fn get(&self, key: &str) -> Result<Option<String>, DBError> {
|
|
self.get(key)
|
|
}
|
|
|
|
fn set(&self, key: String, value: String) -> Result<(), DBError> {
|
|
self.set(key, value)
|
|
}
|
|
|
|
fn setx(&self, key: String, value: String, expire_ms: u128) -> Result<(), DBError> {
|
|
self.setx(key, value, expire_ms)
|
|
}
|
|
|
|
fn del(&self, key: String) -> Result<(), DBError> {
|
|
self.del(key)
|
|
}
|
|
|
|
fn exists(&self, key: &str) -> Result<bool, DBError> {
|
|
self.exists(key)
|
|
}
|
|
|
|
fn keys(&self, pattern: &str) -> Result<Vec<String>, DBError> {
|
|
self.keys(pattern)
|
|
}
|
|
|
|
fn dbsize(&self) -> Result<i64, DBError> {
|
|
self.dbsize()
|
|
}
|
|
|
|
fn flushdb(&self) -> Result<(), DBError> {
|
|
self.flushdb()
|
|
}
|
|
|
|
fn get_key_type(&self, key: &str) -> Result<Option<String>, DBError> {
|
|
self.get_key_type(key)
|
|
}
|
|
|
|
fn scan(&self, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<(String, String)>), DBError> {
|
|
self.scan(cursor, pattern, count)
|
|
}
|
|
|
|
fn hscan(&self, key: &str, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<(String, String)>), DBError> {
|
|
self.hscan(key, cursor, pattern, count)
|
|
}
|
|
|
|
fn hset(&self, key: &str, pairs: Vec<(String, String)>) -> Result<i64, DBError> {
|
|
self.hset(key, pairs)
|
|
}
|
|
|
|
fn hget(&self, key: &str, field: &str) -> Result<Option<String>, DBError> {
|
|
self.hget(key, field)
|
|
}
|
|
|
|
fn hgetall(&self, key: &str) -> Result<Vec<(String, String)>, DBError> {
|
|
self.hgetall(key)
|
|
}
|
|
|
|
fn hdel(&self, key: &str, fields: Vec<String>) -> Result<i64, DBError> {
|
|
self.hdel(key, fields)
|
|
}
|
|
|
|
fn hexists(&self, key: &str, field: &str) -> Result<bool, DBError> {
|
|
self.hexists(key, field)
|
|
}
|
|
|
|
fn hkeys(&self, key: &str) -> Result<Vec<String>, DBError> {
|
|
self.hkeys(key)
|
|
}
|
|
|
|
fn hvals(&self, key: &str) -> Result<Vec<String>, DBError> {
|
|
self.hvals(key)
|
|
}
|
|
|
|
fn hlen(&self, key: &str) -> Result<i64, DBError> {
|
|
self.hlen(key)
|
|
}
|
|
|
|
fn hmget(&self, key: &str, fields: Vec<String>) -> Result<Vec<Option<String>>, DBError> {
|
|
self.hmget(key, fields)
|
|
}
|
|
|
|
fn hsetnx(&self, key: &str, field: &str, value: &str) -> Result<bool, DBError> {
|
|
self.hsetnx(key, field, value)
|
|
}
|
|
|
|
fn lpush(&self, key: &str, elements: Vec<String>) -> Result<i64, DBError> {
|
|
self.lpush(key, elements)
|
|
}
|
|
|
|
fn rpush(&self, key: &str, elements: Vec<String>) -> Result<i64, DBError> {
|
|
self.rpush(key, elements)
|
|
}
|
|
|
|
fn lpop(&self, key: &str, count: u64) -> Result<Vec<String>, DBError> {
|
|
self.lpop(key, count)
|
|
}
|
|
|
|
fn rpop(&self, key: &str, count: u64) -> Result<Vec<String>, DBError> {
|
|
self.rpop(key, count)
|
|
}
|
|
|
|
fn llen(&self, key: &str) -> Result<i64, DBError> {
|
|
self.llen(key)
|
|
}
|
|
|
|
fn lindex(&self, key: &str, index: i64) -> Result<Option<String>, DBError> {
|
|
self.lindex(key, index)
|
|
}
|
|
|
|
fn lrange(&self, key: &str, start: i64, stop: i64) -> Result<Vec<String>, DBError> {
|
|
self.lrange(key, start, stop)
|
|
}
|
|
|
|
fn ltrim(&self, key: &str, start: i64, stop: i64) -> Result<(), DBError> {
|
|
self.ltrim(key, start, stop)
|
|
}
|
|
|
|
fn lrem(&self, key: &str, count: i64, element: &str) -> Result<i64, DBError> {
|
|
self.lrem(key, count, element)
|
|
}
|
|
|
|
fn ttl(&self, key: &str) -> Result<i64, DBError> {
|
|
self.ttl(key)
|
|
}
|
|
|
|
fn expire_seconds(&self, key: &str, secs: u64) -> Result<bool, DBError> {
|
|
self.expire_seconds(key, secs)
|
|
}
|
|
|
|
fn pexpire_millis(&self, key: &str, ms: u128) -> Result<bool, DBError> {
|
|
self.pexpire_millis(key, ms)
|
|
}
|
|
|
|
fn persist(&self, key: &str) -> Result<bool, DBError> {
|
|
self.persist(key)
|
|
}
|
|
|
|
fn expire_at_seconds(&self, key: &str, ts_secs: i64) -> Result<bool, DBError> {
|
|
self.expire_at_seconds(key, ts_secs)
|
|
}
|
|
|
|
fn pexpire_at_millis(&self, key: &str, ts_ms: i64) -> Result<bool, DBError> {
|
|
self.pexpire_at_millis(key, ts_ms)
|
|
}
|
|
|
|
fn is_encrypted(&self) -> bool {
|
|
self.is_encrypted()
|
|
}
|
|
|
|
fn info(&self) -> Result<Vec<(String, String)>, DBError> {
|
|
self.info()
|
|
}
|
|
|
|
fn clone_arc(&self) -> Arc<dyn StorageBackend> {
|
|
unimplemented!("Storage cloning not yet implemented for redb backend")
|
|
}
|
|
} |