...
This commit is contained in:
308
src/age.rs
Normal file
308
src/age.rs
Normal file
@@ -0,0 +1,308 @@
|
||||
//! age.rs — AGE (rage) helpers + persistent key management for your mini-Redis.
|
||||
//
|
||||
// Features:
|
||||
// - X25519 encryption/decryption (age style)
|
||||
// - Ed25519 detached signatures + verification
|
||||
// - Persistent named keys in DB (strings):
|
||||
// age:key:{name} -> X25519 recipient (public encryption key, "age1...")
|
||||
// age:privkey:{name} -> X25519 identity (secret encryption key, "AGE-SECRET-KEY-1...")
|
||||
// age:signpub:{name} -> Ed25519 verify pubkey (public, used to verify signatures)
|
||||
// age:signpriv:{name} -> Ed25519 signing secret key (private, used to sign)
|
||||
// - Base64 wrapping for ciphertext/signature binary blobs.
|
||||
|
||||
use std::str::FromStr;
|
||||
|
||||
use secrecy::ExposeSecret;
|
||||
use age::{Decryptor, Encryptor};
|
||||
use age::x25519;
|
||||
|
||||
use ed25519_dalek::{Signature, Signer, Verifier, SigningKey, VerifyingKey};
|
||||
|
||||
use base64::{engine::general_purpose::STANDARD as B64, Engine as _};
|
||||
|
||||
use crate::protocol::Protocol;
|
||||
use crate::server::Server;
|
||||
use crate::error::DBError;
|
||||
|
||||
// ---------- Internal helpers ----------
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum AgeWireError {
|
||||
ParseKey,
|
||||
Crypto(String),
|
||||
Utf8,
|
||||
SignatureLen,
|
||||
NotFound(&'static str), // which kind of key was missing
|
||||
Storage(String),
|
||||
}
|
||||
|
||||
impl AgeWireError {
|
||||
fn to_protocol(self) -> Protocol {
|
||||
match self {
|
||||
AgeWireError::ParseKey => Protocol::err("ERR age: invalid key"),
|
||||
AgeWireError::Crypto(e) => Protocol::err(&format!("ERR age: {e}")),
|
||||
AgeWireError::Utf8 => Protocol::err("ERR age: invalid UTF-8 plaintext"),
|
||||
AgeWireError::SignatureLen => Protocol::err("ERR age: bad signature length"),
|
||||
AgeWireError::NotFound(w) => Protocol::err(&format!("ERR age: missing {w}")),
|
||||
AgeWireError::Storage(e) => Protocol::err(&format!("ERR storage: {e}")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_recipient(s: &str) -> Result<x25519::Recipient, AgeWireError> {
|
||||
x25519::Recipient::from_str(s).map_err(|_| AgeWireError::ParseKey)
|
||||
}
|
||||
fn parse_identity(s: &str) -> Result<x25519::Identity, AgeWireError> {
|
||||
x25519::Identity::from_str(s).map_err(|_| AgeWireError::ParseKey)
|
||||
}
|
||||
fn parse_ed25519_signing_key(s: &str) -> Result<SigningKey, AgeWireError> {
|
||||
// Parse base64-encoded signing key
|
||||
let bytes = B64.decode(s).map_err(|_| AgeWireError::ParseKey)?;
|
||||
if bytes.len() != 32 {
|
||||
return Err(AgeWireError::ParseKey);
|
||||
}
|
||||
let key_bytes: [u8; 32] = bytes.try_into().map_err(|_| AgeWireError::ParseKey)?;
|
||||
Ok(SigningKey::from_bytes(&key_bytes))
|
||||
}
|
||||
fn parse_ed25519_verifying_key(s: &str) -> Result<VerifyingKey, AgeWireError> {
|
||||
// Parse base64-encoded verifying key
|
||||
let bytes = B64.decode(s).map_err(|_| AgeWireError::ParseKey)?;
|
||||
if bytes.len() != 32 {
|
||||
return Err(AgeWireError::ParseKey);
|
||||
}
|
||||
let key_bytes: [u8; 32] = bytes.try_into().map_err(|_| AgeWireError::ParseKey)?;
|
||||
VerifyingKey::from_bytes(&key_bytes).map_err(|_| AgeWireError::ParseKey)
|
||||
}
|
||||
|
||||
// ---------- Stateless crypto helpers (string in/out) ----------
|
||||
|
||||
pub fn gen_enc_keypair() -> (String, String) {
|
||||
let id = x25519::Identity::generate();
|
||||
let pk = id.to_public();
|
||||
(pk.to_string(), id.to_string().expose_secret().to_string()) // (recipient, identity)
|
||||
}
|
||||
|
||||
pub fn gen_sign_keypair() -> (String, String) {
|
||||
use rand::RngCore;
|
||||
use rand::rngs::OsRng;
|
||||
|
||||
// Generate random 32 bytes for the signing key
|
||||
let mut secret_bytes = [0u8; 32];
|
||||
OsRng.fill_bytes(&mut secret_bytes);
|
||||
|
||||
let signing_key = SigningKey::from_bytes(&secret_bytes);
|
||||
let verifying_key = signing_key.verifying_key();
|
||||
|
||||
// Encode as base64 for storage
|
||||
let signing_key_b64 = B64.encode(signing_key.to_bytes());
|
||||
let verifying_key_b64 = B64.encode(verifying_key.to_bytes());
|
||||
|
||||
(verifying_key_b64, signing_key_b64) // (verify_pub, signing_secret)
|
||||
}
|
||||
|
||||
/// Encrypt `msg` for `recipient_str` (X25519). Returns base64(ciphertext).
|
||||
pub fn encrypt_b64(recipient_str: &str, msg: &str) -> Result<String, AgeWireError> {
|
||||
let recipient = parse_recipient(recipient_str)?;
|
||||
let enc = Encryptor::with_recipients(vec![Box::new(recipient)])
|
||||
.expect("failed to create encryptor"); // Handle Option<Encryptor>
|
||||
let mut out = Vec::new();
|
||||
{
|
||||
use std::io::Write;
|
||||
let mut w = enc.wrap_output(&mut out).map_err(|e| AgeWireError::Crypto(e.to_string()))?;
|
||||
w.write_all(msg.as_bytes()).map_err(|e| AgeWireError::Crypto(e.to_string()))?;
|
||||
w.finish().map_err(|e| AgeWireError::Crypto(e.to_string()))?;
|
||||
}
|
||||
Ok(B64.encode(out))
|
||||
}
|
||||
|
||||
/// Decrypt base64(ciphertext) with `identity_str`. Returns plaintext String.
|
||||
pub fn decrypt_b64(identity_str: &str, ct_b64: &str) -> Result<String, AgeWireError> {
|
||||
let id = parse_identity(identity_str)?;
|
||||
let ct = B64.decode(ct_b64.as_bytes()).map_err(|e| AgeWireError::Crypto(e.to_string()))?;
|
||||
let dec = Decryptor::new(&ct[..]).map_err(|e| AgeWireError::Crypto(e.to_string()))?;
|
||||
|
||||
// The decrypt method returns a Result<StreamReader, DecryptError>
|
||||
let mut r = match dec {
|
||||
Decryptor::Recipients(d) => d.decrypt(std::iter::once(&id as &dyn age::Identity))
|
||||
.map_err(|e| AgeWireError::Crypto(e.to_string()))?,
|
||||
Decryptor::Passphrase(_) => return Err(AgeWireError::Crypto("Expected recipients, got passphrase".to_string())),
|
||||
};
|
||||
|
||||
let mut pt = Vec::new();
|
||||
use std::io::Read;
|
||||
r.read_to_end(&mut pt).map_err(|e| AgeWireError::Crypto(e.to_string()))?;
|
||||
String::from_utf8(pt).map_err(|_| AgeWireError::Utf8)
|
||||
}
|
||||
|
||||
/// Sign bytes of `msg` (detached). Returns base64(signature bytes, 64 bytes).
|
||||
pub fn sign_b64(signing_secret_str: &str, msg: &str) -> Result<String, AgeWireError> {
|
||||
let signing_key = parse_ed25519_signing_key(signing_secret_str)?;
|
||||
let sig = signing_key.sign(msg.as_bytes());
|
||||
Ok(B64.encode(sig.to_bytes()))
|
||||
}
|
||||
|
||||
/// Verify detached signature (base64) for `msg` with pubkey.
|
||||
pub fn verify_b64(verify_pub_str: &str, msg: &str, sig_b64: &str) -> Result<bool, AgeWireError> {
|
||||
let verifying_key = parse_ed25519_verifying_key(verify_pub_str)?;
|
||||
let sig_bytes = B64.decode(sig_b64.as_bytes()).map_err(|e| AgeWireError::Crypto(e.to_string()))?;
|
||||
if sig_bytes.len() != 64 {
|
||||
return Err(AgeWireError::SignatureLen);
|
||||
}
|
||||
let sig = Signature::from_bytes(sig_bytes[..].try_into().unwrap());
|
||||
Ok(verifying_key.verify(msg.as_bytes(), &sig).is_ok())
|
||||
}
|
||||
|
||||
// ---------- Storage helpers ----------
|
||||
|
||||
fn sget(server: &Server, key: &str) -> Result<Option<String>, AgeWireError> {
|
||||
let st = server.current_storage().map_err(|e| AgeWireError::Storage(e.0))?;
|
||||
st.get(key).map_err(|e| AgeWireError::Storage(e.0))
|
||||
}
|
||||
fn sset(server: &Server, key: &str, val: &str) -> Result<(), AgeWireError> {
|
||||
let st = server.current_storage().map_err(|e| AgeWireError::Storage(e.0))?;
|
||||
st.set(key.to_string(), val.to_string()).map_err(|e| AgeWireError::Storage(e.0))
|
||||
}
|
||||
|
||||
fn enc_pub_key_key(name: &str) -> String { format!("age:key:{name}") }
|
||||
fn enc_priv_key_key(name: &str) -> String { format!("age:privkey:{name}") }
|
||||
fn sign_pub_key_key(name: &str) -> String { format!("age:signpub:{name}") }
|
||||
fn sign_priv_key_key(name: &str) -> String { format!("age:signpriv:{name}") }
|
||||
|
||||
// ---------- Command handlers (RESP Protocol) ----------
|
||||
// Basic (stateless) ones kept for completeness
|
||||
|
||||
pub async fn cmd_age_genenc() -> Protocol {
|
||||
let (recip, ident) = gen_enc_keypair();
|
||||
Protocol::Array(vec![Protocol::BulkString(recip), Protocol::BulkString(ident)])
|
||||
}
|
||||
|
||||
pub async fn cmd_age_gensign() -> Protocol {
|
||||
let (verify, secret) = gen_sign_keypair();
|
||||
Protocol::Array(vec![Protocol::BulkString(verify), Protocol::BulkString(secret)])
|
||||
}
|
||||
|
||||
pub async fn cmd_age_encrypt(recipient: &str, message: &str) -> Protocol {
|
||||
match encrypt_b64(recipient, message) {
|
||||
Ok(b64) => Protocol::BulkString(b64),
|
||||
Err(e) => e.to_protocol(),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn cmd_age_decrypt(identity: &str, ct_b64: &str) -> Protocol {
|
||||
match decrypt_b64(identity, ct_b64) {
|
||||
Ok(pt) => Protocol::BulkString(pt),
|
||||
Err(e) => e.to_protocol(),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn cmd_age_sign(secret: &str, message: &str) -> Protocol {
|
||||
match sign_b64(secret, message) {
|
||||
Ok(b64sig) => Protocol::BulkString(b64sig),
|
||||
Err(e) => e.to_protocol(),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn cmd_age_verify(verify_pub: &str, message: &str, sig_b64: &str) -> Protocol {
|
||||
match verify_b64(verify_pub, message, sig_b64) {
|
||||
Ok(true) => Protocol::SimpleString("1".to_string()),
|
||||
Ok(false) => Protocol::SimpleString("0".to_string()),
|
||||
Err(e) => e.to_protocol(),
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- NEW: Persistent, named-key commands ----------
|
||||
|
||||
pub async fn cmd_age_keygen(server: &Server, name: &str) -> Protocol {
|
||||
let (recip, ident) = gen_enc_keypair();
|
||||
if let Err(e) = sset(server, &enc_pub_key_key(name), &recip) { return e.to_protocol(); }
|
||||
if let Err(e) = sset(server, &enc_priv_key_key(name), &ident) { return e.to_protocol(); }
|
||||
Protocol::Array(vec![Protocol::BulkString(recip), Protocol::BulkString(ident)])
|
||||
}
|
||||
|
||||
pub async fn cmd_age_signkeygen(server: &Server, name: &str) -> Protocol {
|
||||
let (verify, secret) = gen_sign_keypair();
|
||||
if let Err(e) = sset(server, &sign_pub_key_key(name), &verify) { return e.to_protocol(); }
|
||||
if let Err(e) = sset(server, &sign_priv_key_key(name), &secret) { return e.to_protocol(); }
|
||||
Protocol::Array(vec![Protocol::BulkString(verify), Protocol::BulkString(secret)])
|
||||
}
|
||||
|
||||
pub async fn cmd_age_encrypt_name(server: &Server, name: &str, message: &str) -> Protocol {
|
||||
let recip = match sget(server, &enc_pub_key_key(name)) {
|
||||
Ok(Some(v)) => v,
|
||||
Ok(None) => return AgeWireError::NotFound("recipient (age:key:{name})").to_protocol(),
|
||||
Err(e) => return e.to_protocol(),
|
||||
};
|
||||
match encrypt_b64(&recip, message) {
|
||||
Ok(ct) => Protocol::BulkString(ct),
|
||||
Err(e) => e.to_protocol(),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn cmd_age_decrypt_name(server: &Server, name: &str, ct_b64: &str) -> Protocol {
|
||||
let ident = match sget(server, &enc_priv_key_key(name)) {
|
||||
Ok(Some(v)) => v,
|
||||
Ok(None) => return AgeWireError::NotFound("identity (age:privkey:{name})").to_protocol(),
|
||||
Err(e) => return e.to_protocol(),
|
||||
};
|
||||
match decrypt_b64(&ident, ct_b64) {
|
||||
Ok(pt) => Protocol::BulkString(pt),
|
||||
Err(e) => e.to_protocol(),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn cmd_age_sign_name(server: &Server, name: &str, message: &str) -> Protocol {
|
||||
let sec = match sget(server, &sign_priv_key_key(name)) {
|
||||
Ok(Some(v)) => v,
|
||||
Ok(None) => return AgeWireError::NotFound("signing secret (age:signpriv:{name})").to_protocol(),
|
||||
Err(e) => return e.to_protocol(),
|
||||
};
|
||||
match sign_b64(&sec, message) {
|
||||
Ok(sig) => Protocol::BulkString(sig),
|
||||
Err(e) => e.to_protocol(),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn cmd_age_verify_name(server: &Server, name: &str, message: &str, sig_b64: &str) -> Protocol {
|
||||
let pubk = match sget(server, &sign_pub_key_key(name)) {
|
||||
Ok(Some(v)) => v,
|
||||
Ok(None) => return AgeWireError::NotFound("verify pubkey (age:signpub:{name})").to_protocol(),
|
||||
Err(e) => return e.to_protocol(),
|
||||
};
|
||||
match verify_b64(&pubk, message, sig_b64) {
|
||||
Ok(true) => Protocol::SimpleString("1".to_string()),
|
||||
Ok(false) => Protocol::SimpleString("0".to_string()),
|
||||
Err(e) => e.to_protocol(),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn cmd_age_list(server: &Server) -> Protocol {
|
||||
// Returns 4 arrays: ["encpub", <names...>], ["encpriv", ...], ["signpub", ...], ["signpriv", ...]
|
||||
let st = match server.current_storage() { Ok(s) => s, Err(e) => return Protocol::err(&e.0) };
|
||||
|
||||
let pull = |pat: &str, prefix: &str| -> Result<Vec<String>, DBError> {
|
||||
let keys = st.keys(pat)?;
|
||||
let mut names: Vec<String> = keys.into_iter()
|
||||
.filter_map(|k| k.strip_prefix(prefix).map(|x| x.to_string()))
|
||||
.collect();
|
||||
names.sort();
|
||||
Ok(names)
|
||||
};
|
||||
|
||||
let encpub = match pull("age:key:*", "age:key:") { Ok(v) => v, Err(e)=> return Protocol::err(&e.0) };
|
||||
let encpriv = match pull("age:privkey:*", "age:privkey:") { Ok(v) => v, Err(e)=> return Protocol::err(&e.0) };
|
||||
let signpub = match pull("age:signpub:*", "age:signpub:") { Ok(v) => v, Err(e)=> return Protocol::err(&e.0) };
|
||||
let signpriv= match pull("age:signpriv:*", "age:signpriv:") { Ok(v) => v, Err(e)=> return Protocol::err(&e.0) };
|
||||
|
||||
let to_arr = |label: &str, v: Vec<String>| {
|
||||
let mut out = vec![Protocol::BulkString(label.to_string())];
|
||||
out.push(Protocol::Array(v.into_iter().map(Protocol::BulkString).collect()));
|
||||
Protocol::Array(out)
|
||||
};
|
||||
|
||||
Protocol::Array(vec![
|
||||
to_arr("encpub", encpub),
|
||||
to_arr("encpriv", encpriv),
|
||||
to_arr("signpub", signpub),
|
||||
to_arr("signpriv", signpriv),
|
||||
])
|
||||
}
|
1721
src/cmd.rs
Normal file
1721
src/cmd.rs
Normal file
File diff suppressed because it is too large
Load Diff
74
src/crypto.rs
Normal file
74
src/crypto.rs
Normal file
@@ -0,0 +1,74 @@
|
||||
use chacha20poly1305::{
|
||||
aead::{Aead, KeyInit, OsRng},
|
||||
XChaCha20Poly1305, XNonce,
|
||||
};
|
||||
use rand::RngCore;
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
const VERSION: u8 = 1;
|
||||
const NONCE_LEN: usize = 24;
|
||||
const TAG_LEN: usize = 16;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum CryptoError {
|
||||
Format, // wrong length / header
|
||||
Version(u8), // unknown version
|
||||
Decrypt, // wrong key or corrupted data
|
||||
}
|
||||
|
||||
impl From<CryptoError> for crate::error::DBError {
|
||||
fn from(e: CryptoError) -> Self {
|
||||
crate::error::DBError(format!("Crypto error: {:?}", e))
|
||||
}
|
||||
}
|
||||
|
||||
/// Super-simple factory: new(secret) + encrypt(bytes) + decrypt(bytes)
|
||||
#[derive(Clone)]
|
||||
pub struct CryptoFactory {
|
||||
key: chacha20poly1305::Key,
|
||||
}
|
||||
|
||||
impl CryptoFactory {
|
||||
/// Accepts any secret bytes; turns them into a 32-byte key (SHA-256).
|
||||
pub fn new<S: AsRef<[u8]>>(secret: S) -> Self {
|
||||
let mut h = Sha256::new();
|
||||
h.update(b"xchacha20poly1305-factory:v1"); // domain separation
|
||||
h.update(secret.as_ref());
|
||||
let digest = h.finalize(); // 32 bytes
|
||||
let key = chacha20poly1305::Key::from_slice(&digest).to_owned();
|
||||
Self { key }
|
||||
}
|
||||
|
||||
/// Output layout: [version:1][nonce:24][ciphertext||tag]
|
||||
pub fn encrypt(&self, plaintext: &[u8]) -> Vec<u8> {
|
||||
let cipher = XChaCha20Poly1305::new(&self.key);
|
||||
|
||||
let mut nonce_bytes = [0u8; NONCE_LEN];
|
||||
OsRng.fill_bytes(&mut nonce_bytes);
|
||||
let nonce = XNonce::from_slice(&nonce_bytes);
|
||||
|
||||
let mut out = Vec::with_capacity(1 + NONCE_LEN + plaintext.len() + TAG_LEN);
|
||||
out.push(VERSION);
|
||||
out.extend_from_slice(&nonce_bytes);
|
||||
|
||||
let ct = cipher.encrypt(nonce, plaintext).expect("encrypt");
|
||||
out.extend_from_slice(&ct);
|
||||
out
|
||||
}
|
||||
|
||||
pub fn decrypt(&self, blob: &[u8]) -> Result<Vec<u8>, CryptoError> {
|
||||
if blob.len() < 1 + NONCE_LEN + TAG_LEN {
|
||||
return Err(CryptoError::Format);
|
||||
}
|
||||
let ver = blob[0];
|
||||
if ver != VERSION {
|
||||
return Err(CryptoError::Version(ver));
|
||||
}
|
||||
|
||||
let nonce = XNonce::from_slice(&blob[1..1 + NONCE_LEN]);
|
||||
let ct = &blob[1 + NONCE_LEN..];
|
||||
|
||||
let cipher = XChaCha20Poly1305::new(&self.key);
|
||||
cipher.decrypt(nonce, ct).map_err(|_| CryptoError::Decrypt)
|
||||
}
|
||||
}
|
94
src/error.rs
Normal file
94
src/error.rs
Normal file
@@ -0,0 +1,94 @@
|
||||
use std::num::ParseIntError;
|
||||
|
||||
use tokio::sync::mpsc;
|
||||
use redb;
|
||||
use bincode;
|
||||
|
||||
|
||||
// todo: more error types
|
||||
#[derive(Debug)]
|
||||
pub struct DBError(pub String);
|
||||
|
||||
impl From<std::io::Error> for DBError {
|
||||
fn from(item: std::io::Error) -> Self {
|
||||
DBError(item.to_string().clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ParseIntError> for DBError {
|
||||
fn from(item: ParseIntError) -> Self {
|
||||
DBError(item.to_string().clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<std::str::Utf8Error> for DBError {
|
||||
fn from(item: std::str::Utf8Error) -> Self {
|
||||
DBError(item.to_string().clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<std::string::FromUtf8Error> for DBError {
|
||||
fn from(item: std::string::FromUtf8Error) -> Self {
|
||||
DBError(item.to_string().clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<redb::Error> for DBError {
|
||||
fn from(item: redb::Error) -> Self {
|
||||
DBError(item.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<redb::DatabaseError> for DBError {
|
||||
fn from(item: redb::DatabaseError) -> Self {
|
||||
DBError(item.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<redb::TransactionError> for DBError {
|
||||
fn from(item: redb::TransactionError) -> Self {
|
||||
DBError(item.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<redb::TableError> for DBError {
|
||||
fn from(item: redb::TableError) -> Self {
|
||||
DBError(item.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<redb::StorageError> for DBError {
|
||||
fn from(item: redb::StorageError) -> Self {
|
||||
DBError(item.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<redb::CommitError> for DBError {
|
||||
fn from(item: redb::CommitError) -> Self {
|
||||
DBError(item.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Box<bincode::ErrorKind>> for DBError {
|
||||
fn from(item: Box<bincode::ErrorKind>) -> Self {
|
||||
DBError(item.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<tokio::sync::mpsc::error::SendError<()>> for DBError {
|
||||
fn from(item: mpsc::error::SendError<()>) -> Self {
|
||||
DBError(item.to_string().clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<serde_json::Error> for DBError {
|
||||
fn from(item: serde_json::Error) -> Self {
|
||||
DBError(item.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<chacha20poly1305::Error> for DBError {
|
||||
fn from(item: chacha20poly1305::Error) -> Self {
|
||||
DBError(item.to_string())
|
||||
}
|
||||
}
|
12
src/lib.rs
Normal file
12
src/lib.rs
Normal file
@@ -0,0 +1,12 @@
|
||||
pub mod age; // NEW
|
||||
pub mod cmd;
|
||||
pub mod crypto;
|
||||
pub mod error;
|
||||
pub mod options;
|
||||
pub mod protocol;
|
||||
pub mod search_cmd; // Add this
|
||||
pub mod server;
|
||||
pub mod storage;
|
||||
pub mod storage_trait; // Add this
|
||||
pub mod storage_sled; // Add this
|
||||
pub mod tantivy_search;
|
90
src/main.rs
Normal file
90
src/main.rs
Normal file
@@ -0,0 +1,90 @@
|
||||
// #![allow(unused_imports)]
|
||||
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
use herodb::server;
|
||||
|
||||
use clap::Parser;
|
||||
|
||||
/// Simple program to greet a person
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// The directory of Redis DB file
|
||||
#[arg(long)]
|
||||
dir: String,
|
||||
|
||||
/// The port of the Redis server, default is 6379 if not specified
|
||||
#[arg(long)]
|
||||
port: Option<u16>,
|
||||
|
||||
/// Enable debug mode
|
||||
#[arg(long)]
|
||||
debug: bool,
|
||||
|
||||
|
||||
/// Master encryption key for encrypted databases
|
||||
#[arg(long)]
|
||||
encryption_key: Option<String>,
|
||||
|
||||
/// Encrypt the database
|
||||
#[arg(long)]
|
||||
encrypt: bool,
|
||||
|
||||
/// Use the sled backend
|
||||
#[arg(long)]
|
||||
sled: bool,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
// parse args
|
||||
let args = Args::parse();
|
||||
|
||||
// bind port
|
||||
let port = args.port.unwrap_or(6379);
|
||||
println!("will listen on port: {}", port);
|
||||
let listener = TcpListener::bind(format!("127.0.0.1:{}", port))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// new DB option
|
||||
let option = herodb::options::DBOption {
|
||||
dir: args.dir,
|
||||
port,
|
||||
debug: args.debug,
|
||||
encryption_key: args.encryption_key,
|
||||
encrypt: args.encrypt,
|
||||
backend: if args.sled {
|
||||
herodb::options::BackendType::Sled
|
||||
} else {
|
||||
herodb::options::BackendType::Redb
|
||||
},
|
||||
};
|
||||
|
||||
// new server
|
||||
let server = server::Server::new(option).await;
|
||||
|
||||
// Add a small delay to ensure the port is ready
|
||||
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
||||
|
||||
// accept new connections
|
||||
loop {
|
||||
let stream = listener.accept().await;
|
||||
match stream {
|
||||
Ok((stream, _)) => {
|
||||
println!("accepted new connection");
|
||||
|
||||
let mut sc = server.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = sc.handle(stream).await {
|
||||
println!("error: {:?}, will close the connection. Bye", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
Err(e) => {
|
||||
println!("error: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
15
src/options.rs
Normal file
15
src/options.rs
Normal file
@@ -0,0 +1,15 @@
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum BackendType {
|
||||
Redb,
|
||||
Sled,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DBOption {
|
||||
pub dir: String,
|
||||
pub port: u16,
|
||||
pub debug: bool,
|
||||
pub encrypt: bool,
|
||||
pub encryption_key: Option<String>,
|
||||
pub backend: BackendType,
|
||||
}
|
171
src/protocol.rs
Normal file
171
src/protocol.rs
Normal file
@@ -0,0 +1,171 @@
|
||||
use core::fmt;
|
||||
|
||||
use crate::error::DBError;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Protocol {
|
||||
SimpleString(String),
|
||||
BulkString(String),
|
||||
Null,
|
||||
Array(Vec<Protocol>),
|
||||
Error(String), // NEW
|
||||
}
|
||||
|
||||
impl fmt::Display for Protocol {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "{}", self.decode().as_str())
|
||||
}
|
||||
}
|
||||
|
||||
impl Protocol {
|
||||
pub fn from(protocol: &str) -> Result<(Self, &str), DBError> {
|
||||
if protocol.is_empty() {
|
||||
// Incomplete frame; caller should read more bytes
|
||||
return Err(DBError("[incomplete] empty".to_string()));
|
||||
}
|
||||
let ret = match protocol.chars().nth(0) {
|
||||
Some('+') => Self::parse_simple_string_sfx(&protocol[1..]),
|
||||
Some('$') => Self::parse_bulk_string_sfx(&protocol[1..]),
|
||||
Some('*') => Self::parse_array_sfx(&protocol[1..]),
|
||||
_ => Err(DBError(format!(
|
||||
"[from] unsupported protocol: {:?}",
|
||||
protocol
|
||||
))),
|
||||
};
|
||||
ret
|
||||
}
|
||||
|
||||
pub fn from_vec(array: Vec<&str>) -> Self {
|
||||
let array = array
|
||||
.into_iter()
|
||||
.map(|x| Protocol::BulkString(x.to_string()))
|
||||
.collect();
|
||||
Protocol::Array(array)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn ok() -> Self {
|
||||
Protocol::SimpleString("ok".to_string())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn err(msg: &str) -> Self {
|
||||
Protocol::Error(msg.to_string())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn write_on_slave_err() -> Self {
|
||||
Self::err("DISALLOW WRITE ON SLAVE")
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn psync_on_slave_err() -> Self {
|
||||
Self::err("PSYNC ON SLAVE IS NOT ALLOWED")
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn none() -> Self {
|
||||
Self::SimpleString("none".to_string())
|
||||
}
|
||||
|
||||
pub fn decode(&self) -> String {
|
||||
match self {
|
||||
Protocol::SimpleString(s) => s.to_string(),
|
||||
Protocol::BulkString(s) => s.to_string(),
|
||||
Protocol::Null => "".to_string(),
|
||||
Protocol::Array(s) => s.iter().map(|x| x.decode()).collect::<Vec<_>>().join(" "),
|
||||
Protocol::Error(s) => s.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn encode(&self) -> String {
|
||||
match self {
|
||||
Protocol::SimpleString(s) => format!("+{}\r\n", s),
|
||||
Protocol::BulkString(s) => format!("${}\r\n{}\r\n", s.len(), s),
|
||||
Protocol::Array(ss) => {
|
||||
format!("*{}\r\n", ss.len()) + &ss.iter().map(|x| x.encode()).collect::<String>()
|
||||
}
|
||||
Protocol::Null => "$-1\r\n".to_string(),
|
||||
Protocol::Error(s) => format!("-{}\r\n", s), // proper RESP error
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_simple_string_sfx(protocol: &str) -> Result<(Self, &str), DBError> {
|
||||
match protocol.find("\r\n") {
|
||||
Some(x) => Ok((Self::SimpleString(protocol[..x].to_string()), &protocol[x + 2..])),
|
||||
_ => Err(DBError(format!(
|
||||
"[new simple string] unsupported protocol: {:?}",
|
||||
protocol
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_bulk_string_sfx(protocol: &str) -> Result<(Self, &str), DBError> {
|
||||
if let Some(len_end) = protocol.find("\r\n") {
|
||||
let size = Self::parse_usize(&protocol[..len_end])?;
|
||||
let data_start = len_end + 2;
|
||||
let data_end = data_start + size;
|
||||
|
||||
// If we don't yet have the full bulk payload + trailing CRLF, signal INCOMPLETE
|
||||
if protocol.len() < data_end + 2 {
|
||||
return Err(DBError("[incomplete] bulk body".to_string()));
|
||||
}
|
||||
if &protocol[data_end..data_end + 2] != "\r\n" {
|
||||
return Err(DBError("[incomplete] bulk terminator".to_string()));
|
||||
}
|
||||
|
||||
let s = Self::parse_string(&protocol[data_start..data_end])?;
|
||||
Ok((Protocol::BulkString(s), &protocol[data_end + 2..]))
|
||||
} else {
|
||||
// No CRLF after bulk length header yet
|
||||
Err(DBError("[incomplete] bulk header".to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_array_sfx(s: &str) -> Result<(Self, &str), DBError> {
|
||||
if let Some(len_end) = s.find("\r\n") {
|
||||
let array_len = s[..len_end].parse::<usize>()?;
|
||||
let mut remaining = &s[len_end + 2..];
|
||||
let mut vec = vec![];
|
||||
for _ in 0..array_len {
|
||||
match Protocol::from(remaining) {
|
||||
Ok((p, rem)) => {
|
||||
vec.push(p);
|
||||
remaining = rem;
|
||||
}
|
||||
Err(e) => {
|
||||
// Propagate incomplete so caller can read more bytes
|
||||
if e.0.starts_with("[incomplete]") {
|
||||
return Err(e);
|
||||
} else {
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok((Protocol::Array(vec), remaining))
|
||||
} else {
|
||||
// No CRLF after array header yet
|
||||
Err(DBError("[incomplete] array header".to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_usize(protocol: &str) -> Result<usize, DBError> {
|
||||
if protocol.is_empty() {
|
||||
Err(DBError("Cannot parse usize from empty string".to_string()))
|
||||
} else {
|
||||
protocol
|
||||
.parse::<usize>()
|
||||
.map_err(|_| DBError(format!("Failed to parse usize from: {}", protocol)))
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_string(protocol: &str) -> Result<String, DBError> {
|
||||
if protocol.is_empty() {
|
||||
// Allow empty strings, but handle appropriately
|
||||
Ok("".to_string())
|
||||
} else {
|
||||
Ok(protocol.to_string())
|
||||
}
|
||||
}
|
||||
}
|
272
src/search_cmd.rs
Normal file
272
src/search_cmd.rs
Normal file
@@ -0,0 +1,272 @@
|
||||
use crate::{
|
||||
error::DBError,
|
||||
protocol::Protocol,
|
||||
server::Server,
|
||||
tantivy_search::{
|
||||
TantivySearch, FieldDef, NumericType, IndexConfig,
|
||||
SearchOptions, Filter, FilterType
|
||||
},
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub async fn ft_create_cmd(
|
||||
server: &Server,
|
||||
index_name: String,
|
||||
schema: Vec<(String, String, Vec<String>)>,
|
||||
) -> Result<Protocol, DBError> {
|
||||
// Parse schema into field definitions
|
||||
let mut field_definitions = Vec::new();
|
||||
|
||||
for (field_name, field_type, options) in schema {
|
||||
let field_def = match field_type.to_uppercase().as_str() {
|
||||
"TEXT" => {
|
||||
let mut weight = 1.0;
|
||||
let mut sortable = false;
|
||||
let mut no_index = false;
|
||||
|
||||
for opt in &options {
|
||||
match opt.to_uppercase().as_str() {
|
||||
"WEIGHT" => {
|
||||
// Next option should be the weight value
|
||||
if let Some(idx) = options.iter().position(|x| x == opt) {
|
||||
if idx + 1 < options.len() {
|
||||
weight = options[idx + 1].parse().unwrap_or(1.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
"SORTABLE" => sortable = true,
|
||||
"NOINDEX" => no_index = true,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
FieldDef::Text {
|
||||
stored: true,
|
||||
indexed: !no_index,
|
||||
tokenized: true,
|
||||
fast: sortable,
|
||||
}
|
||||
}
|
||||
"NUMERIC" => {
|
||||
let mut sortable = false;
|
||||
|
||||
for opt in &options {
|
||||
if opt.to_uppercase() == "SORTABLE" {
|
||||
sortable = true;
|
||||
}
|
||||
}
|
||||
|
||||
FieldDef::Numeric {
|
||||
stored: true,
|
||||
indexed: true,
|
||||
fast: sortable,
|
||||
precision: NumericType::F64,
|
||||
}
|
||||
}
|
||||
"TAG" => {
|
||||
let mut separator = ",".to_string();
|
||||
let mut case_sensitive = false;
|
||||
|
||||
for i in 0..options.len() {
|
||||
match options[i].to_uppercase().as_str() {
|
||||
"SEPARATOR" => {
|
||||
if i + 1 < options.len() {
|
||||
separator = options[i + 1].clone();
|
||||
}
|
||||
}
|
||||
"CASESENSITIVE" => case_sensitive = true,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
FieldDef::Tag {
|
||||
stored: true,
|
||||
separator,
|
||||
case_sensitive,
|
||||
}
|
||||
}
|
||||
"GEO" => {
|
||||
FieldDef::Geo { stored: true }
|
||||
}
|
||||
_ => {
|
||||
return Err(DBError(format!("Unknown field type: {}", field_type)));
|
||||
}
|
||||
};
|
||||
|
||||
field_definitions.push((field_name, field_def));
|
||||
}
|
||||
|
||||
// Create the search index
|
||||
let search_path = server.search_index_path();
|
||||
let config = IndexConfig::default();
|
||||
|
||||
println!("Creating search index '{}' at path: {:?}", index_name, search_path);
|
||||
println!("Field definitions: {:?}", field_definitions);
|
||||
|
||||
let search_index = TantivySearch::new_with_schema(
|
||||
search_path,
|
||||
index_name.clone(),
|
||||
field_definitions,
|
||||
Some(config),
|
||||
)?;
|
||||
|
||||
println!("Search index '{}' created successfully", index_name);
|
||||
|
||||
// Store in registry
|
||||
let mut indexes = server.search_indexes.write().unwrap();
|
||||
indexes.insert(index_name, Arc::new(search_index));
|
||||
|
||||
Ok(Protocol::SimpleString("OK".to_string()))
|
||||
}
|
||||
|
||||
pub async fn ft_add_cmd(
|
||||
server: &Server,
|
||||
index_name: String,
|
||||
doc_id: String,
|
||||
_score: f64,
|
||||
fields: HashMap<String, String>,
|
||||
) -> Result<Protocol, DBError> {
|
||||
let indexes = server.search_indexes.read().unwrap();
|
||||
|
||||
let search_index = indexes.get(&index_name)
|
||||
.ok_or_else(|| DBError(format!("Index '{}' not found", index_name)))?;
|
||||
|
||||
search_index.add_document_with_fields(&doc_id, fields)?;
|
||||
|
||||
Ok(Protocol::SimpleString("OK".to_string()))
|
||||
}
|
||||
|
||||
pub async fn ft_search_cmd(
|
||||
server: &Server,
|
||||
index_name: String,
|
||||
query: String,
|
||||
filters: Vec<(String, String)>,
|
||||
limit: Option<usize>,
|
||||
offset: Option<usize>,
|
||||
return_fields: Option<Vec<String>>,
|
||||
) -> Result<Protocol, DBError> {
|
||||
let indexes = server.search_indexes.read().unwrap();
|
||||
|
||||
let search_index = indexes.get(&index_name)
|
||||
.ok_or_else(|| DBError(format!("Index '{}' not found", index_name)))?;
|
||||
|
||||
// Convert filters to search filters
|
||||
let search_filters = filters.into_iter().map(|(field, value)| {
|
||||
Filter {
|
||||
field,
|
||||
filter_type: FilterType::Equals(value),
|
||||
}
|
||||
}).collect();
|
||||
|
||||
let options = SearchOptions {
|
||||
limit: limit.unwrap_or(10),
|
||||
offset: offset.unwrap_or(0),
|
||||
filters: search_filters,
|
||||
sort_by: None,
|
||||
return_fields,
|
||||
highlight: false,
|
||||
};
|
||||
|
||||
let results = search_index.search_with_options(&query, options)?;
|
||||
|
||||
// Format results as Redis protocol
|
||||
let mut response = Vec::new();
|
||||
|
||||
// First element is the total count
|
||||
response.push(Protocol::SimpleString(results.total.to_string()));
|
||||
|
||||
// Then each document
|
||||
for doc in results.documents {
|
||||
let mut doc_array = Vec::new();
|
||||
|
||||
// Add document ID if it exists
|
||||
if let Some(id) = doc.fields.get("_id") {
|
||||
doc_array.push(Protocol::BulkString(id.clone()));
|
||||
}
|
||||
|
||||
// Add score
|
||||
doc_array.push(Protocol::BulkString(doc.score.to_string()));
|
||||
|
||||
// Add fields as key-value pairs
|
||||
for (field_name, field_value) in doc.fields {
|
||||
if field_name != "_id" {
|
||||
doc_array.push(Protocol::BulkString(field_name));
|
||||
doc_array.push(Protocol::BulkString(field_value));
|
||||
}
|
||||
}
|
||||
|
||||
response.push(Protocol::Array(doc_array));
|
||||
}
|
||||
|
||||
Ok(Protocol::Array(response))
|
||||
}
|
||||
|
||||
pub async fn ft_del_cmd(
|
||||
server: &Server,
|
||||
index_name: String,
|
||||
doc_id: String,
|
||||
) -> Result<Protocol, DBError> {
|
||||
let indexes = server.search_indexes.read().unwrap();
|
||||
|
||||
let _search_index = indexes.get(&index_name)
|
||||
.ok_or_else(|| DBError(format!("Index '{}' not found", index_name)))?;
|
||||
|
||||
// For now, return success
|
||||
// In a full implementation, we'd need to add a delete method to TantivySearch
|
||||
println!("Deleting document '{}' from index '{}'", doc_id, index_name);
|
||||
|
||||
Ok(Protocol::SimpleString("1".to_string()))
|
||||
}
|
||||
|
||||
pub async fn ft_info_cmd(
|
||||
server: &Server,
|
||||
index_name: String,
|
||||
) -> Result<Protocol, DBError> {
|
||||
let indexes = server.search_indexes.read().unwrap();
|
||||
|
||||
let search_index = indexes.get(&index_name)
|
||||
.ok_or_else(|| DBError(format!("Index '{}' not found", index_name)))?;
|
||||
|
||||
let info = search_index.get_info()?;
|
||||
|
||||
// Format info as Redis protocol
|
||||
let mut response = Vec::new();
|
||||
|
||||
response.push(Protocol::BulkString("index_name".to_string()));
|
||||
response.push(Protocol::BulkString(info.name));
|
||||
|
||||
response.push(Protocol::BulkString("num_docs".to_string()));
|
||||
response.push(Protocol::BulkString(info.num_docs.to_string()));
|
||||
|
||||
response.push(Protocol::BulkString("num_fields".to_string()));
|
||||
response.push(Protocol::BulkString(info.fields.len().to_string()));
|
||||
|
||||
response.push(Protocol::BulkString("fields".to_string()));
|
||||
let fields_str = info.fields.iter()
|
||||
.map(|f| format!("{}:{}", f.name, f.field_type))
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ");
|
||||
response.push(Protocol::BulkString(fields_str));
|
||||
|
||||
Ok(Protocol::Array(response))
|
||||
}
|
||||
|
||||
pub async fn ft_drop_cmd(
|
||||
server: &Server,
|
||||
index_name: String,
|
||||
) -> Result<Protocol, DBError> {
|
||||
let mut indexes = server.search_indexes.write().unwrap();
|
||||
|
||||
if indexes.remove(&index_name).is_some() {
|
||||
// Also remove the index files from disk
|
||||
let index_path = server.search_index_path().join(&index_name);
|
||||
if index_path.exists() {
|
||||
std::fs::remove_dir_all(index_path)
|
||||
.map_err(|e| DBError(format!("Failed to remove index files: {}", e)))?;
|
||||
}
|
||||
Ok(Protocol::SimpleString("OK".to_string()))
|
||||
} else {
|
||||
Err(DBError(format!("Index '{}' not found", index_name)))
|
||||
}
|
||||
}
|
272
src/server.rs
Normal file
272
src/server.rs
Normal file
@@ -0,0 +1,272 @@
|
||||
use core::str;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::io::AsyncReadExt;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::sync::{Mutex, oneshot};
|
||||
use std::sync::RwLock;
|
||||
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
|
||||
use crate::cmd::Cmd;
|
||||
use crate::error::DBError;
|
||||
use crate::options;
|
||||
use crate::protocol::Protocol;
|
||||
use crate::storage::Storage;
|
||||
use crate::storage_sled::SledStorage;
|
||||
use crate::storage_trait::StorageBackend;
|
||||
use crate::tantivy_search::TantivySearch;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Server {
|
||||
pub db_cache: Arc<RwLock<HashMap<u64, Arc<dyn StorageBackend>>>>,
|
||||
pub search_indexes: Arc<RwLock<HashMap<String, Arc<TantivySearch>>>>,
|
||||
pub option: options::DBOption,
|
||||
pub client_name: Option<String>,
|
||||
pub selected_db: u64, // Changed from usize to u64
|
||||
pub queued_cmd: Option<Vec<(Cmd, Protocol)>>,
|
||||
|
||||
// BLPOP waiter registry: per (db_index, key) FIFO of waiters
|
||||
pub list_waiters: Arc<Mutex<HashMap<u64, HashMap<String, Vec<Waiter>>>>>,
|
||||
pub waiter_seq: Arc<AtomicU64>,
|
||||
}
|
||||
|
||||
pub struct Waiter {
|
||||
pub id: u64,
|
||||
pub side: PopSide,
|
||||
pub tx: oneshot::Sender<(String, String)>, // (key, element)
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub enum PopSide {
|
||||
Left,
|
||||
Right,
|
||||
}
|
||||
|
||||
impl Server {
|
||||
pub async fn new(option: options::DBOption) -> Self {
|
||||
Server {
|
||||
db_cache: Arc::new(RwLock::new(HashMap::new())),
|
||||
search_indexes: Arc::new(RwLock::new(HashMap::new())),
|
||||
option,
|
||||
client_name: None,
|
||||
selected_db: 0,
|
||||
queued_cmd: None,
|
||||
|
||||
list_waiters: Arc::new(Mutex::new(HashMap::new())),
|
||||
waiter_seq: Arc::new(AtomicU64::new(1)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn current_storage(&self) -> Result<Arc<dyn StorageBackend>, DBError> {
|
||||
let mut cache = self.db_cache.write().unwrap();
|
||||
|
||||
if let Some(storage) = cache.get(&self.selected_db) {
|
||||
return Ok(storage.clone());
|
||||
}
|
||||
|
||||
|
||||
// Create new database file
|
||||
let db_file_path = std::path::PathBuf::from(self.option.dir.clone())
|
||||
.join(format!("{}.db", self.selected_db));
|
||||
|
||||
// Ensure the directory exists before creating the database file
|
||||
if let Some(parent_dir) = db_file_path.parent() {
|
||||
std::fs::create_dir_all(parent_dir).map_err(|e| {
|
||||
DBError(format!("Failed to create directory {}: {}", parent_dir.display(), e))
|
||||
})?;
|
||||
}
|
||||
|
||||
println!("Creating new db file: {}", db_file_path.display());
|
||||
|
||||
let storage: Arc<dyn StorageBackend> = match self.option.backend {
|
||||
options::BackendType::Redb => {
|
||||
Arc::new(Storage::new(
|
||||
db_file_path,
|
||||
self.should_encrypt_db(self.selected_db),
|
||||
self.option.encryption_key.as_deref()
|
||||
)?)
|
||||
}
|
||||
options::BackendType::Sled => {
|
||||
Arc::new(SledStorage::new(
|
||||
db_file_path,
|
||||
self.should_encrypt_db(self.selected_db),
|
||||
self.option.encryption_key.as_deref()
|
||||
)?)
|
||||
}
|
||||
};
|
||||
|
||||
cache.insert(self.selected_db, storage.clone());
|
||||
Ok(storage)
|
||||
}
|
||||
|
||||
fn should_encrypt_db(&self, db_index: u64) -> bool {
|
||||
// DB 0-9 are non-encrypted, DB 10+ are encrypted
|
||||
self.option.encrypt && db_index >= 10
|
||||
}
|
||||
|
||||
// Add method to get search index path
|
||||
pub fn search_index_path(&self) -> std::path::PathBuf {
|
||||
std::path::PathBuf::from(&self.option.dir).join("search_indexes")
|
||||
}
|
||||
|
||||
// ----- BLPOP waiter helpers -----
|
||||
|
||||
pub async fn register_waiter(&self, db_index: u64, key: &str, side: PopSide) -> (u64, oneshot::Receiver<(String, String)>) {
|
||||
let id = self.waiter_seq.fetch_add(1, Ordering::Relaxed);
|
||||
let (tx, rx) = oneshot::channel::<(String, String)>();
|
||||
|
||||
let mut guard = self.list_waiters.lock().await;
|
||||
let per_db = guard.entry(db_index).or_insert_with(HashMap::new);
|
||||
let q = per_db.entry(key.to_string()).or_insert_with(Vec::new);
|
||||
q.push(Waiter { id, side, tx });
|
||||
(id, rx)
|
||||
}
|
||||
|
||||
pub async fn unregister_waiter(&self, db_index: u64, key: &str, id: u64) {
|
||||
let mut guard = self.list_waiters.lock().await;
|
||||
if let Some(per_db) = guard.get_mut(&db_index) {
|
||||
if let Some(q) = per_db.get_mut(key) {
|
||||
q.retain(|w| w.id != id);
|
||||
if q.is_empty() {
|
||||
per_db.remove(key);
|
||||
}
|
||||
}
|
||||
if per_db.is_empty() {
|
||||
guard.remove(&db_index);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Called after LPUSH/RPUSH to deliver to blocked BLPOP waiters.
|
||||
pub async fn drain_waiters_after_push(&self, key: &str) -> Result<(), DBError> {
|
||||
let db_index = self.selected_db;
|
||||
|
||||
loop {
|
||||
// Check if any waiter exists
|
||||
let maybe_waiter = {
|
||||
let mut guard = self.list_waiters.lock().await;
|
||||
if let Some(per_db) = guard.get_mut(&db_index) {
|
||||
if let Some(q) = per_db.get_mut(key) {
|
||||
if !q.is_empty() {
|
||||
// Pop FIFO
|
||||
Some(q.remove(0))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
let waiter = if let Some(w) = maybe_waiter { w } else { break };
|
||||
|
||||
// Pop one element depending on waiter side
|
||||
let elems = match waiter.side {
|
||||
PopSide::Left => self.current_storage()?.lpop(key, 1)?,
|
||||
PopSide::Right => self.current_storage()?.rpop(key, 1)?,
|
||||
};
|
||||
if elems.is_empty() {
|
||||
// Nothing to deliver; re-register waiter at the front to preserve order
|
||||
let mut guard = self.list_waiters.lock().await;
|
||||
let per_db = guard.entry(db_index).or_insert_with(HashMap::new);
|
||||
let q = per_db.entry(key.to_string()).or_insert_with(Vec::new);
|
||||
q.insert(0, waiter);
|
||||
break;
|
||||
} else {
|
||||
let elem = elems[0].clone();
|
||||
// Send to waiter; if receiver dropped, just continue
|
||||
let _ = waiter.tx.send((key.to_string(), elem));
|
||||
// Loop to try to satisfy more waiters if more elements remain
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn handle(
|
||||
&mut self,
|
||||
mut stream: tokio::net::TcpStream,
|
||||
) -> Result<(), DBError> {
|
||||
// Accumulate incoming bytes to handle partial RESP frames
|
||||
let mut acc = String::new();
|
||||
let mut buf = vec![0u8; 8192];
|
||||
|
||||
loop {
|
||||
let n = match stream.read(&mut buf).await {
|
||||
Ok(0) => {
|
||||
println!("[handle] connection closed");
|
||||
return Ok(());
|
||||
}
|
||||
Ok(n) => n,
|
||||
Err(e) => {
|
||||
println!("[handle] read error: {:?}", e);
|
||||
return Err(e.into());
|
||||
}
|
||||
};
|
||||
|
||||
// Append to accumulator. RESP for our usage is ASCII-safe.
|
||||
acc.push_str(str::from_utf8(&buf[..n])?);
|
||||
|
||||
// Try to parse as many complete commands as are available in 'acc'.
|
||||
loop {
|
||||
let parsed = Cmd::from(&acc);
|
||||
let (cmd, protocol, remaining) = match parsed {
|
||||
Ok((cmd, protocol, remaining)) => (cmd, protocol, remaining),
|
||||
Err(_e) => {
|
||||
// Incomplete or invalid frame; assume incomplete and wait for more data.
|
||||
// This avoids emitting spurious protocol_error for split frames.
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
// Advance the accumulator to the unparsed remainder
|
||||
acc = remaining.to_string();
|
||||
|
||||
if self.option.debug {
|
||||
println!("\x1b[34;1mgot command: {:?}, protocol: {:?}\x1b[0m", cmd, protocol);
|
||||
} else {
|
||||
println!("got command: {:?}, protocol: {:?}", cmd, protocol);
|
||||
}
|
||||
|
||||
// Check if this is a QUIT command before processing
|
||||
let is_quit = matches!(cmd, Cmd::Quit);
|
||||
|
||||
let res = match cmd.run(self).await {
|
||||
Ok(p) => p,
|
||||
Err(e) => {
|
||||
if self.option.debug {
|
||||
eprintln!("[run error] {:?}", e);
|
||||
}
|
||||
Protocol::err(&format!("ERR {}", e.0))
|
||||
}
|
||||
};
|
||||
|
||||
if self.option.debug {
|
||||
println!("\x1b[34;1mqueued cmd {:?}\x1b[0m", self.queued_cmd);
|
||||
println!("\x1b[32;1mgoing to send response {}\x1b[0m", res.encode());
|
||||
} else {
|
||||
print!("queued cmd {:?}", self.queued_cmd);
|
||||
println!("going to send response {}", res.encode());
|
||||
}
|
||||
|
||||
_ = stream.write(res.encode().as_bytes()).await?;
|
||||
|
||||
// If this was a QUIT command, close the connection
|
||||
if is_quit {
|
||||
println!("[handle] QUIT command received, closing connection");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Continue parsing any further complete commands already in 'acc'
|
||||
if acc.is_empty() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
287
src/storage/mod.rs
Normal file
287
src/storage/mod.rs
Normal file
@@ -0,0 +1,287 @@
|
||||
use std::{
|
||||
path::Path,
|
||||
sync::Arc,
|
||||
time::{SystemTime, UNIX_EPOCH},
|
||||
};
|
||||
|
||||
use redb::{Database, TableDefinition};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::crypto::CryptoFactory;
|
||||
use crate::error::DBError;
|
||||
|
||||
// Re-export modules
|
||||
mod storage_basic;
|
||||
mod storage_hset;
|
||||
mod storage_lists;
|
||||
mod storage_extra;
|
||||
|
||||
// Re-export implementations
|
||||
// Note: These imports are used by the impl blocks in the submodules
|
||||
// The compiler shows them as unused because they're not directly used in this file
|
||||
// but they're needed for the Storage struct methods to be available
|
||||
pub use storage_extra::*;
|
||||
|
||||
// Table definitions for different Redis data types
|
||||
const TYPES_TABLE: TableDefinition<&str, &str> = TableDefinition::new("types");
|
||||
const STRINGS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("strings");
|
||||
const HASHES_TABLE: TableDefinition<(&str, &str), &[u8]> = TableDefinition::new("hashes");
|
||||
const LISTS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("lists");
|
||||
const STREAMS_META_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("streams_meta");
|
||||
const STREAMS_DATA_TABLE: TableDefinition<(&str, &str), &[u8]> = TableDefinition::new("streams_data");
|
||||
const ENCRYPTED_TABLE: TableDefinition<&str, u8> = TableDefinition::new("encrypted");
|
||||
const EXPIRATION_TABLE: TableDefinition<&str, u64> = TableDefinition::new("expiration");
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct StreamEntry {
|
||||
pub fields: Vec<(String, String)>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct ListValue {
|
||||
pub elements: Vec<String>,
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn now_in_millis() -> u128 {
|
||||
let start = SystemTime::now();
|
||||
let duration_since_epoch = start.duration_since(UNIX_EPOCH).unwrap();
|
||||
duration_since_epoch.as_millis()
|
||||
}
|
||||
|
||||
pub struct Storage {
|
||||
db: Database,
|
||||
crypto: Option<CryptoFactory>,
|
||||
}
|
||||
|
||||
impl Storage {
|
||||
pub fn new(path: impl AsRef<Path>, should_encrypt: bool, master_key: Option<&str>) -> Result<Self, DBError> {
|
||||
let db = Database::create(path)?;
|
||||
|
||||
// Create tables if they don't exist
|
||||
let write_txn = db.begin_write()?;
|
||||
{
|
||||
let _ = write_txn.open_table(TYPES_TABLE)?;
|
||||
let _ = write_txn.open_table(STRINGS_TABLE)?;
|
||||
let _ = write_txn.open_table(HASHES_TABLE)?;
|
||||
let _ = write_txn.open_table(LISTS_TABLE)?;
|
||||
let _ = write_txn.open_table(STREAMS_META_TABLE)?;
|
||||
let _ = write_txn.open_table(STREAMS_DATA_TABLE)?;
|
||||
let _ = write_txn.open_table(ENCRYPTED_TABLE)?;
|
||||
let _ = write_txn.open_table(EXPIRATION_TABLE)?;
|
||||
}
|
||||
write_txn.commit()?;
|
||||
|
||||
// Check if database was previously encrypted
|
||||
let read_txn = db.begin_read()?;
|
||||
let encrypted_table = read_txn.open_table(ENCRYPTED_TABLE)?;
|
||||
let was_encrypted = encrypted_table.get("encrypted")?.map(|v| v.value() == 1).unwrap_or(false);
|
||||
drop(read_txn);
|
||||
|
||||
let crypto = if should_encrypt || was_encrypted {
|
||||
if let Some(key) = master_key {
|
||||
Some(CryptoFactory::new(key.as_bytes()))
|
||||
} else {
|
||||
return Err(DBError("Encryption requested but no master key provided".to_string()));
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// If we're enabling encryption for the first time, mark it
|
||||
if should_encrypt && !was_encrypted {
|
||||
let write_txn = db.begin_write()?;
|
||||
{
|
||||
let mut encrypted_table = write_txn.open_table(ENCRYPTED_TABLE)?;
|
||||
encrypted_table.insert("encrypted", &1u8)?;
|
||||
}
|
||||
write_txn.commit()?;
|
||||
}
|
||||
|
||||
Ok(Storage {
|
||||
db,
|
||||
crypto,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn is_encrypted(&self) -> bool {
|
||||
self.crypto.is_some()
|
||||
}
|
||||
|
||||
// Helper methods for encryption
|
||||
fn encrypt_if_needed(&self, data: &[u8]) -> Result<Vec<u8>, DBError> {
|
||||
if let Some(crypto) = &self.crypto {
|
||||
Ok(crypto.encrypt(data))
|
||||
} else {
|
||||
Ok(data.to_vec())
|
||||
}
|
||||
}
|
||||
|
||||
fn decrypt_if_needed(&self, data: &[u8]) -> Result<Vec<u8>, DBError> {
|
||||
if let Some(crypto) = &self.crypto {
|
||||
Ok(crypto.decrypt(data)?)
|
||||
} else {
|
||||
Ok(data.to_vec())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
use crate::storage_trait::StorageBackend;
|
||||
|
||||
impl StorageBackend for Storage {
|
||||
fn get(&self, key: &str) -> Result<Option<String>, DBError> {
|
||||
self.get(key)
|
||||
}
|
||||
|
||||
fn set(&self, key: String, value: String) -> Result<(), DBError> {
|
||||
self.set(key, value)
|
||||
}
|
||||
|
||||
fn setx(&self, key: String, value: String, expire_ms: u128) -> Result<(), DBError> {
|
||||
self.setx(key, value, expire_ms)
|
||||
}
|
||||
|
||||
fn del(&self, key: String) -> Result<(), DBError> {
|
||||
self.del(key)
|
||||
}
|
||||
|
||||
fn exists(&self, key: &str) -> Result<bool, DBError> {
|
||||
self.exists(key)
|
||||
}
|
||||
|
||||
fn keys(&self, pattern: &str) -> Result<Vec<String>, DBError> {
|
||||
self.keys(pattern)
|
||||
}
|
||||
|
||||
fn dbsize(&self) -> Result<i64, DBError> {
|
||||
self.dbsize()
|
||||
}
|
||||
|
||||
fn flushdb(&self) -> Result<(), DBError> {
|
||||
self.flushdb()
|
||||
}
|
||||
|
||||
fn get_key_type(&self, key: &str) -> Result<Option<String>, DBError> {
|
||||
self.get_key_type(key)
|
||||
}
|
||||
|
||||
fn scan(&self, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<(String, String)>), DBError> {
|
||||
self.scan(cursor, pattern, count)
|
||||
}
|
||||
|
||||
fn hscan(&self, key: &str, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<(String, String)>), DBError> {
|
||||
self.hscan(key, cursor, pattern, count)
|
||||
}
|
||||
|
||||
fn hset(&self, key: &str, pairs: Vec<(String, String)>) -> Result<i64, DBError> {
|
||||
self.hset(key, pairs)
|
||||
}
|
||||
|
||||
fn hget(&self, key: &str, field: &str) -> Result<Option<String>, DBError> {
|
||||
self.hget(key, field)
|
||||
}
|
||||
|
||||
fn hgetall(&self, key: &str) -> Result<Vec<(String, String)>, DBError> {
|
||||
self.hgetall(key)
|
||||
}
|
||||
|
||||
fn hdel(&self, key: &str, fields: Vec<String>) -> Result<i64, DBError> {
|
||||
self.hdel(key, fields)
|
||||
}
|
||||
|
||||
fn hexists(&self, key: &str, field: &str) -> Result<bool, DBError> {
|
||||
self.hexists(key, field)
|
||||
}
|
||||
|
||||
fn hkeys(&self, key: &str) -> Result<Vec<String>, DBError> {
|
||||
self.hkeys(key)
|
||||
}
|
||||
|
||||
fn hvals(&self, key: &str) -> Result<Vec<String>, DBError> {
|
||||
self.hvals(key)
|
||||
}
|
||||
|
||||
fn hlen(&self, key: &str) -> Result<i64, DBError> {
|
||||
self.hlen(key)
|
||||
}
|
||||
|
||||
fn hmget(&self, key: &str, fields: Vec<String>) -> Result<Vec<Option<String>>, DBError> {
|
||||
self.hmget(key, fields)
|
||||
}
|
||||
|
||||
fn hsetnx(&self, key: &str, field: &str, value: &str) -> Result<bool, DBError> {
|
||||
self.hsetnx(key, field, value)
|
||||
}
|
||||
|
||||
fn lpush(&self, key: &str, elements: Vec<String>) -> Result<i64, DBError> {
|
||||
self.lpush(key, elements)
|
||||
}
|
||||
|
||||
fn rpush(&self, key: &str, elements: Vec<String>) -> Result<i64, DBError> {
|
||||
self.rpush(key, elements)
|
||||
}
|
||||
|
||||
fn lpop(&self, key: &str, count: u64) -> Result<Vec<String>, DBError> {
|
||||
self.lpop(key, count)
|
||||
}
|
||||
|
||||
fn rpop(&self, key: &str, count: u64) -> Result<Vec<String>, DBError> {
|
||||
self.rpop(key, count)
|
||||
}
|
||||
|
||||
fn llen(&self, key: &str) -> Result<i64, DBError> {
|
||||
self.llen(key)
|
||||
}
|
||||
|
||||
fn lindex(&self, key: &str, index: i64) -> Result<Option<String>, DBError> {
|
||||
self.lindex(key, index)
|
||||
}
|
||||
|
||||
fn lrange(&self, key: &str, start: i64, stop: i64) -> Result<Vec<String>, DBError> {
|
||||
self.lrange(key, start, stop)
|
||||
}
|
||||
|
||||
fn ltrim(&self, key: &str, start: i64, stop: i64) -> Result<(), DBError> {
|
||||
self.ltrim(key, start, stop)
|
||||
}
|
||||
|
||||
fn lrem(&self, key: &str, count: i64, element: &str) -> Result<i64, DBError> {
|
||||
self.lrem(key, count, element)
|
||||
}
|
||||
|
||||
fn ttl(&self, key: &str) -> Result<i64, DBError> {
|
||||
self.ttl(key)
|
||||
}
|
||||
|
||||
fn expire_seconds(&self, key: &str, secs: u64) -> Result<bool, DBError> {
|
||||
self.expire_seconds(key, secs)
|
||||
}
|
||||
|
||||
fn pexpire_millis(&self, key: &str, ms: u128) -> Result<bool, DBError> {
|
||||
self.pexpire_millis(key, ms)
|
||||
}
|
||||
|
||||
fn persist(&self, key: &str) -> Result<bool, DBError> {
|
||||
self.persist(key)
|
||||
}
|
||||
|
||||
fn expire_at_seconds(&self, key: &str, ts_secs: i64) -> Result<bool, DBError> {
|
||||
self.expire_at_seconds(key, ts_secs)
|
||||
}
|
||||
|
||||
fn pexpire_at_millis(&self, key: &str, ts_ms: i64) -> Result<bool, DBError> {
|
||||
self.pexpire_at_millis(key, ts_ms)
|
||||
}
|
||||
|
||||
fn is_encrypted(&self) -> bool {
|
||||
self.is_encrypted()
|
||||
}
|
||||
|
||||
fn info(&self) -> Result<Vec<(String, String)>, DBError> {
|
||||
self.info()
|
||||
}
|
||||
|
||||
fn clone_arc(&self) -> Arc<dyn StorageBackend> {
|
||||
unimplemented!("Storage cloning not yet implemented for redb backend")
|
||||
}
|
||||
}
|
245
src/storage/storage_basic.rs
Normal file
245
src/storage/storage_basic.rs
Normal file
@@ -0,0 +1,245 @@
|
||||
use redb::{ReadableTable};
|
||||
use crate::error::DBError;
|
||||
use super::*;
|
||||
|
||||
impl Storage {
|
||||
pub fn flushdb(&self) -> Result<(), DBError> {
|
||||
let write_txn = self.db.begin_write()?;
|
||||
{
|
||||
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
|
||||
let mut strings_table = write_txn.open_table(STRINGS_TABLE)?;
|
||||
let mut hashes_table = write_txn.open_table(HASHES_TABLE)?;
|
||||
let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
|
||||
let mut streams_meta_table = write_txn.open_table(STREAMS_META_TABLE)?;
|
||||
let mut streams_data_table = write_txn.open_table(STREAMS_DATA_TABLE)?;
|
||||
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
|
||||
|
||||
// inefficient, but there is no other way
|
||||
let keys: Vec<String> = types_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect();
|
||||
for key in keys {
|
||||
types_table.remove(key.as_str())?;
|
||||
}
|
||||
let keys: Vec<String> = strings_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect();
|
||||
for key in keys {
|
||||
strings_table.remove(key.as_str())?;
|
||||
}
|
||||
let keys: Vec<(String, String)> = hashes_table
|
||||
.iter()?
|
||||
.map(|item| {
|
||||
let binding = item.unwrap();
|
||||
let (k, f) = binding.0.value();
|
||||
(k.to_string(), f.to_string())
|
||||
})
|
||||
.collect();
|
||||
for (key, field) in keys {
|
||||
hashes_table.remove((key.as_str(), field.as_str()))?;
|
||||
}
|
||||
let keys: Vec<String> = lists_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect();
|
||||
for key in keys {
|
||||
lists_table.remove(key.as_str())?;
|
||||
}
|
||||
let keys: Vec<String> = streams_meta_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect();
|
||||
for key in keys {
|
||||
streams_meta_table.remove(key.as_str())?;
|
||||
}
|
||||
let keys: Vec<(String,String)> = streams_data_table.iter()?.map(|item| {
|
||||
let binding = item.unwrap();
|
||||
let (key, field) = binding.0.value();
|
||||
(key.to_string(), field.to_string())
|
||||
}).collect();
|
||||
for (key, field) in keys {
|
||||
streams_data_table.remove((key.as_str(), field.as_str()))?;
|
||||
}
|
||||
let keys: Vec<String> = expiration_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect();
|
||||
for key in keys {
|
||||
expiration_table.remove(key.as_str())?;
|
||||
}
|
||||
}
|
||||
write_txn.commit()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn get_key_type(&self, key: &str) -> Result<Option<String>, DBError> {
|
||||
let read_txn = self.db.begin_read()?;
|
||||
let table = read_txn.open_table(TYPES_TABLE)?;
|
||||
|
||||
// Before returning type, check for expiration
|
||||
if let Some(type_val) = table.get(key)? {
|
||||
if type_val.value() == "string" {
|
||||
let expiration_table = read_txn.open_table(EXPIRATION_TABLE)?;
|
||||
if let Some(expires_at) = expiration_table.get(key)? {
|
||||
if now_in_millis() > expires_at.value() as u128 {
|
||||
// The key is expired, so it effectively has no type
|
||||
return Ok(None);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(Some(type_val.value().to_string()))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
// ✅ ENCRYPTION APPLIED: Value is encrypted/decrypted
|
||||
pub fn get(&self, key: &str) -> Result<Option<String>, DBError> {
|
||||
let read_txn = self.db.begin_read()?;
|
||||
|
||||
let types_table = read_txn.open_table(TYPES_TABLE)?;
|
||||
match types_table.get(key)? {
|
||||
Some(type_val) if type_val.value() == "string" => {
|
||||
// Check expiration first (unencrypted)
|
||||
let expiration_table = read_txn.open_table(EXPIRATION_TABLE)?;
|
||||
if let Some(expires_at) = expiration_table.get(key)? {
|
||||
if now_in_millis() > expires_at.value() as u128 {
|
||||
drop(read_txn);
|
||||
self.del(key.to_string())?;
|
||||
return Ok(None);
|
||||
}
|
||||
}
|
||||
|
||||
// Get and decrypt value
|
||||
let strings_table = read_txn.open_table(STRINGS_TABLE)?;
|
||||
match strings_table.get(key)? {
|
||||
Some(data) => {
|
||||
let decrypted = self.decrypt_if_needed(data.value())?;
|
||||
let value = String::from_utf8(decrypted)?;
|
||||
Ok(Some(value))
|
||||
}
|
||||
None => Ok(None),
|
||||
}
|
||||
}
|
||||
_ => Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
// ✅ ENCRYPTION APPLIED: Value is encrypted before storage
|
||||
pub fn set(&self, key: String, value: String) -> Result<(), DBError> {
|
||||
let write_txn = self.db.begin_write()?;
|
||||
|
||||
{
|
||||
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
|
||||
types_table.insert(key.as_str(), "string")?;
|
||||
|
||||
let mut strings_table = write_txn.open_table(STRINGS_TABLE)?;
|
||||
// Only encrypt the value, not expiration
|
||||
let encrypted = self.encrypt_if_needed(value.as_bytes())?;
|
||||
strings_table.insert(key.as_str(), encrypted.as_slice())?;
|
||||
|
||||
// Remove any existing expiration since this is a regular SET
|
||||
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
|
||||
expiration_table.remove(key.as_str())?;
|
||||
}
|
||||
|
||||
write_txn.commit()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ✅ ENCRYPTION APPLIED: Value is encrypted before storage
|
||||
pub fn setx(&self, key: String, value: String, expire_ms: u128) -> Result<(), DBError> {
|
||||
let write_txn = self.db.begin_write()?;
|
||||
|
||||
{
|
||||
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
|
||||
types_table.insert(key.as_str(), "string")?;
|
||||
|
||||
let mut strings_table = write_txn.open_table(STRINGS_TABLE)?;
|
||||
// Only encrypt the value
|
||||
let encrypted = self.encrypt_if_needed(value.as_bytes())?;
|
||||
strings_table.insert(key.as_str(), encrypted.as_slice())?;
|
||||
|
||||
// Store expiration separately (unencrypted)
|
||||
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
|
||||
let expires_at = expire_ms + now_in_millis();
|
||||
expiration_table.insert(key.as_str(), &(expires_at as u64))?;
|
||||
}
|
||||
|
||||
write_txn.commit()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn del(&self, key: String) -> Result<(), DBError> {
|
||||
let write_txn = self.db.begin_write()?;
|
||||
|
||||
{
|
||||
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
|
||||
let mut strings_table = write_txn.open_table(STRINGS_TABLE)?;
|
||||
let mut hashes_table: redb::Table<(&str, &str), &[u8]> = write_txn.open_table(HASHES_TABLE)?;
|
||||
let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
|
||||
|
||||
// Remove from type table
|
||||
types_table.remove(key.as_str())?;
|
||||
|
||||
// Remove from strings table
|
||||
strings_table.remove(key.as_str())?;
|
||||
|
||||
// Remove all hash fields for this key
|
||||
let mut to_remove = Vec::new();
|
||||
let mut iter = hashes_table.iter()?;
|
||||
while let Some(entry) = iter.next() {
|
||||
let entry = entry?;
|
||||
let (hash_key, field) = entry.0.value();
|
||||
if hash_key == key.as_str() {
|
||||
to_remove.push((hash_key.to_string(), field.to_string()));
|
||||
}
|
||||
}
|
||||
drop(iter);
|
||||
|
||||
for (hash_key, field) in to_remove {
|
||||
hashes_table.remove((hash_key.as_str(), field.as_str()))?;
|
||||
}
|
||||
|
||||
// Remove from lists table
|
||||
lists_table.remove(key.as_str())?;
|
||||
|
||||
// Also remove expiration
|
||||
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
|
||||
expiration_table.remove(key.as_str())?;
|
||||
}
|
||||
|
||||
write_txn.commit()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn keys(&self, pattern: &str) -> Result<Vec<String>, DBError> {
|
||||
let read_txn = self.db.begin_read()?;
|
||||
let table = read_txn.open_table(TYPES_TABLE)?;
|
||||
|
||||
let mut keys = Vec::new();
|
||||
let mut iter = table.iter()?;
|
||||
while let Some(entry) = iter.next() {
|
||||
let key = entry?.0.value().to_string();
|
||||
if pattern == "*" || super::storage_extra::glob_match(pattern, &key) {
|
||||
keys.push(key);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(keys)
|
||||
}
|
||||
}
|
||||
|
||||
impl Storage {
|
||||
pub fn dbsize(&self) -> Result<i64, DBError> {
|
||||
let read_txn = self.db.begin_read()?;
|
||||
let types_table = read_txn.open_table(TYPES_TABLE)?;
|
||||
let expiration_table = read_txn.open_table(EXPIRATION_TABLE)?;
|
||||
|
||||
let mut count: i64 = 0;
|
||||
let mut iter = types_table.iter()?;
|
||||
while let Some(entry) = iter.next() {
|
||||
let entry = entry?;
|
||||
let key = entry.0.value();
|
||||
let ty = entry.1.value();
|
||||
|
||||
if ty == "string" {
|
||||
if let Some(expires_at) = expiration_table.get(key)? {
|
||||
if now_in_millis() > expires_at.value() as u128 {
|
||||
// Skip logically expired string keys
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
count += 1;
|
||||
}
|
||||
Ok(count)
|
||||
}
|
||||
}
|
286
src/storage/storage_extra.rs
Normal file
286
src/storage/storage_extra.rs
Normal file
@@ -0,0 +1,286 @@
|
||||
use redb::{ReadableTable};
|
||||
use crate::error::DBError;
|
||||
use super::*;
|
||||
|
||||
impl Storage {
|
||||
// ✅ ENCRYPTION APPLIED: Values are decrypted after retrieval
|
||||
pub fn scan(&self, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<(String, String)>), DBError> {
|
||||
let read_txn = self.db.begin_read()?;
|
||||
let types_table = read_txn.open_table(TYPES_TABLE)?;
|
||||
let strings_table = read_txn.open_table(STRINGS_TABLE)?;
|
||||
|
||||
let mut result = Vec::new();
|
||||
let mut current_cursor = 0u64;
|
||||
let limit = count.unwrap_or(10) as usize;
|
||||
|
||||
let mut iter = types_table.iter()?;
|
||||
while let Some(entry) = iter.next() {
|
||||
let entry = entry?;
|
||||
let key = entry.0.value().to_string();
|
||||
let key_type = entry.1.value().to_string();
|
||||
|
||||
if current_cursor >= cursor {
|
||||
// Apply pattern matching if specified
|
||||
let matches = if let Some(pat) = pattern {
|
||||
glob_match(pat, &key)
|
||||
} else {
|
||||
true
|
||||
};
|
||||
|
||||
if matches {
|
||||
// For scan, we return key-value pairs for string types
|
||||
if key_type == "string" {
|
||||
if let Some(data) = strings_table.get(key.as_str())? {
|
||||
let decrypted = self.decrypt_if_needed(data.value())?;
|
||||
let value = String::from_utf8(decrypted)?;
|
||||
result.push((key, value));
|
||||
} else {
|
||||
result.push((key, String::new()));
|
||||
}
|
||||
} else {
|
||||
// For non-string types, just return the key with type as value
|
||||
result.push((key, key_type));
|
||||
}
|
||||
|
||||
if result.len() >= limit {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
current_cursor += 1;
|
||||
}
|
||||
|
||||
let next_cursor = if result.len() < limit { 0 } else { current_cursor };
|
||||
Ok((next_cursor, result))
|
||||
}
|
||||
|
||||
pub fn ttl(&self, key: &str) -> Result<i64, DBError> {
|
||||
let read_txn = self.db.begin_read()?;
|
||||
let types_table = read_txn.open_table(TYPES_TABLE)?;
|
||||
|
||||
match types_table.get(key)? {
|
||||
Some(type_val) if type_val.value() == "string" => {
|
||||
let expiration_table = read_txn.open_table(EXPIRATION_TABLE)?;
|
||||
match expiration_table.get(key)? {
|
||||
Some(expires_at) => {
|
||||
let now = now_in_millis();
|
||||
let expires_at_ms = expires_at.value() as u128;
|
||||
if now >= expires_at_ms {
|
||||
Ok(-2) // Key has expired
|
||||
} else {
|
||||
Ok(((expires_at_ms - now) / 1000) as i64) // TTL in seconds
|
||||
}
|
||||
}
|
||||
None => Ok(-1), // Key exists but has no expiration
|
||||
}
|
||||
}
|
||||
Some(_) => Ok(-1), // Key exists but is not a string (no expiration support for other types)
|
||||
None => Ok(-2), // Key does not exist
|
||||
}
|
||||
}
|
||||
|
||||
pub fn exists(&self, key: &str) -> Result<bool, DBError> {
|
||||
let read_txn = self.db.begin_read()?;
|
||||
let types_table = read_txn.open_table(TYPES_TABLE)?;
|
||||
|
||||
match types_table.get(key)? {
|
||||
Some(type_val) if type_val.value() == "string" => {
|
||||
// Check if string key has expired
|
||||
let expiration_table = read_txn.open_table(EXPIRATION_TABLE)?;
|
||||
if let Some(expires_at) = expiration_table.get(key)? {
|
||||
if now_in_millis() > expires_at.value() as u128 {
|
||||
return Ok(false); // Key has expired
|
||||
}
|
||||
}
|
||||
Ok(true)
|
||||
}
|
||||
Some(_) => Ok(true), // Key exists and is not a string
|
||||
None => Ok(false), // Key does not exist
|
||||
}
|
||||
}
|
||||
|
||||
// -------- Expiration helpers (string keys only, consistent with TTL/EXISTS) --------
|
||||
|
||||
// Set expiry in seconds; returns true if applied (key exists and is string), false otherwise
|
||||
pub fn expire_seconds(&self, key: &str, secs: u64) -> Result<bool, DBError> {
|
||||
// Determine eligibility first to avoid holding borrows across commit
|
||||
let mut applied = false;
|
||||
let write_txn = self.db.begin_write()?;
|
||||
{
|
||||
let types_table = write_txn.open_table(TYPES_TABLE)?;
|
||||
let is_string = types_table
|
||||
.get(key)?
|
||||
.map(|v| v.value() == "string")
|
||||
.unwrap_or(false);
|
||||
if is_string {
|
||||
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
|
||||
let expires_at = now_in_millis() + (secs as u128) * 1000;
|
||||
expiration_table.insert(key, &(expires_at as u64))?;
|
||||
applied = true;
|
||||
}
|
||||
}
|
||||
write_txn.commit()?;
|
||||
Ok(applied)
|
||||
}
|
||||
|
||||
// Set expiry in milliseconds; returns true if applied (key exists and is string), false otherwise
|
||||
pub fn pexpire_millis(&self, key: &str, ms: u128) -> Result<bool, DBError> {
|
||||
let mut applied = false;
|
||||
let write_txn = self.db.begin_write()?;
|
||||
{
|
||||
let types_table = write_txn.open_table(TYPES_TABLE)?;
|
||||
let is_string = types_table
|
||||
.get(key)?
|
||||
.map(|v| v.value() == "string")
|
||||
.unwrap_or(false);
|
||||
if is_string {
|
||||
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
|
||||
let expires_at = now_in_millis() + ms;
|
||||
expiration_table.insert(key, &(expires_at as u64))?;
|
||||
applied = true;
|
||||
}
|
||||
}
|
||||
write_txn.commit()?;
|
||||
Ok(applied)
|
||||
}
|
||||
|
||||
// Remove expiry if present; returns true if removed, false otherwise
|
||||
pub fn persist(&self, key: &str) -> Result<bool, DBError> {
|
||||
let mut removed = false;
|
||||
let write_txn = self.db.begin_write()?;
|
||||
{
|
||||
let types_table = write_txn.open_table(TYPES_TABLE)?;
|
||||
let is_string = types_table
|
||||
.get(key)?
|
||||
.map(|v| v.value() == "string")
|
||||
.unwrap_or(false);
|
||||
if is_string {
|
||||
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
|
||||
if expiration_table.remove(key)?.is_some() {
|
||||
removed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
write_txn.commit()?;
|
||||
Ok(removed)
|
||||
}
|
||||
|
||||
// Absolute EXPIREAT in seconds since epoch
|
||||
// Returns true if applied (key exists and is string), false otherwise
|
||||
pub fn expire_at_seconds(&self, key: &str, ts_secs: i64) -> Result<bool, DBError> {
|
||||
let mut applied = false;
|
||||
let write_txn = self.db.begin_write()?;
|
||||
{
|
||||
let types_table = write_txn.open_table(TYPES_TABLE)?;
|
||||
let is_string = types_table
|
||||
.get(key)?
|
||||
.map(|v| v.value() == "string")
|
||||
.unwrap_or(false);
|
||||
if is_string {
|
||||
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
|
||||
let expires_at_ms: u128 = if ts_secs <= 0 { 0 } else { (ts_secs as u128) * 1000 };
|
||||
expiration_table.insert(key, &((expires_at_ms as u64)))?;
|
||||
applied = true;
|
||||
}
|
||||
}
|
||||
write_txn.commit()?;
|
||||
Ok(applied)
|
||||
}
|
||||
|
||||
// Absolute PEXPIREAT in milliseconds since epoch
|
||||
// Returns true if applied (key exists and is string), false otherwise
|
||||
pub fn pexpire_at_millis(&self, key: &str, ts_ms: i64) -> Result<bool, DBError> {
|
||||
let mut applied = false;
|
||||
let write_txn = self.db.begin_write()?;
|
||||
{
|
||||
let types_table = write_txn.open_table(TYPES_TABLE)?;
|
||||
let is_string = types_table
|
||||
.get(key)?
|
||||
.map(|v| v.value() == "string")
|
||||
.unwrap_or(false);
|
||||
if is_string {
|
||||
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
|
||||
let expires_at_ms: u128 = if ts_ms <= 0 { 0 } else { ts_ms as u128 };
|
||||
expiration_table.insert(key, &((expires_at_ms as u64)))?;
|
||||
applied = true;
|
||||
}
|
||||
}
|
||||
write_txn.commit()?;
|
||||
Ok(applied)
|
||||
}
|
||||
|
||||
pub fn info(&self) -> Result<Vec<(String, String)>, DBError> {
|
||||
let dbsize = self.dbsize()?;
|
||||
Ok(vec![
|
||||
("db_size".to_string(), dbsize.to_string()),
|
||||
("is_encrypted".to_string(), self.is_encrypted().to_string()),
|
||||
])
|
||||
}
|
||||
}
|
||||
|
||||
// Utility function for glob pattern matching
|
||||
pub fn glob_match(pattern: &str, text: &str) -> bool {
|
||||
if pattern == "*" {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Simple glob matching - supports * and ? wildcards
|
||||
let pattern_chars: Vec<char> = pattern.chars().collect();
|
||||
let text_chars: Vec<char> = text.chars().collect();
|
||||
|
||||
fn match_recursive(pattern: &[char], text: &[char], pi: usize, ti: usize) -> bool {
|
||||
if pi >= pattern.len() {
|
||||
return ti >= text.len();
|
||||
}
|
||||
|
||||
if ti >= text.len() {
|
||||
// Check if remaining pattern is all '*'
|
||||
return pattern[pi..].iter().all(|&c| c == '*');
|
||||
}
|
||||
|
||||
match pattern[pi] {
|
||||
'*' => {
|
||||
// Try matching zero or more characters
|
||||
for i in ti..=text.len() {
|
||||
if match_recursive(pattern, text, pi + 1, i) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
'?' => {
|
||||
// Match exactly one character
|
||||
match_recursive(pattern, text, pi + 1, ti + 1)
|
||||
}
|
||||
c => {
|
||||
// Match exact character
|
||||
if text[ti] == c {
|
||||
match_recursive(pattern, text, pi + 1, ti + 1)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
match_recursive(&pattern_chars, &text_chars, 0, 0)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_glob_match() {
|
||||
assert!(glob_match("*", "anything"));
|
||||
assert!(glob_match("hello", "hello"));
|
||||
assert!(!glob_match("hello", "world"));
|
||||
assert!(glob_match("h*o", "hello"));
|
||||
assert!(glob_match("h*o", "ho"));
|
||||
assert!(!glob_match("h*o", "hi"));
|
||||
assert!(glob_match("h?llo", "hello"));
|
||||
assert!(!glob_match("h?llo", "hllo"));
|
||||
assert!(glob_match("*test*", "this_is_a_test_string"));
|
||||
assert!(!glob_match("*test*", "this_is_a_string"));
|
||||
}
|
||||
}
|
377
src/storage/storage_hset.rs
Normal file
377
src/storage/storage_hset.rs
Normal file
@@ -0,0 +1,377 @@
|
||||
use redb::{ReadableTable};
|
||||
use crate::error::DBError;
|
||||
use super::*;
|
||||
|
||||
impl Storage {
|
||||
// ✅ ENCRYPTION APPLIED: Values are encrypted before storage
|
||||
pub fn hset(&self, key: &str, pairs: Vec<(String, String)>) -> Result<i64, DBError> {
|
||||
let write_txn = self.db.begin_write()?;
|
||||
let mut new_fields = 0i64;
|
||||
|
||||
{
|
||||
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
|
||||
let mut hashes_table = write_txn.open_table(HASHES_TABLE)?;
|
||||
|
||||
let key_type = {
|
||||
let access_guard = types_table.get(key)?;
|
||||
access_guard.map(|v| v.value().to_string())
|
||||
};
|
||||
|
||||
match key_type.as_deref() {
|
||||
Some("hash") | None => { // Proceed if hash or new key
|
||||
// Set the type to hash (only if new key or existing hash)
|
||||
types_table.insert(key, "hash")?;
|
||||
|
||||
for (field, value) in pairs {
|
||||
// Check if field already exists
|
||||
let exists = hashes_table.get((key, field.as_str()))?.is_some();
|
||||
|
||||
// Encrypt the value before storing
|
||||
let encrypted = self.encrypt_if_needed(value.as_bytes())?;
|
||||
hashes_table.insert((key, field.as_str()), encrypted.as_slice())?;
|
||||
|
||||
if !exists {
|
||||
new_fields += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(_) => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
write_txn.commit()?;
|
||||
Ok(new_fields)
|
||||
}
|
||||
|
||||
// ✅ ENCRYPTION APPLIED: Value is decrypted after retrieval
|
||||
pub fn hget(&self, key: &str, field: &str) -> Result<Option<String>, DBError> {
|
||||
let read_txn = self.db.begin_read()?;
|
||||
let types_table = read_txn.open_table(TYPES_TABLE)?;
|
||||
|
||||
let key_type = types_table.get(key)?.map(|v| v.value().to_string());
|
||||
|
||||
match key_type.as_deref() {
|
||||
Some("hash") => {
|
||||
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
|
||||
match hashes_table.get((key, field))? {
|
||||
Some(data) => {
|
||||
let decrypted = self.decrypt_if_needed(data.value())?;
|
||||
let value = String::from_utf8(decrypted)?;
|
||||
Ok(Some(value))
|
||||
}
|
||||
None => Ok(None),
|
||||
}
|
||||
}
|
||||
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
|
||||
None => Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
// ✅ ENCRYPTION APPLIED: All values are decrypted after retrieval
|
||||
pub fn hgetall(&self, key: &str) -> Result<Vec<(String, String)>, DBError> {
|
||||
let read_txn = self.db.begin_read()?;
|
||||
let types_table = read_txn.open_table(TYPES_TABLE)?;
|
||||
let key_type = {
|
||||
let access_guard = types_table.get(key)?;
|
||||
access_guard.map(|v| v.value().to_string())
|
||||
};
|
||||
|
||||
match key_type.as_deref() {
|
||||
Some("hash") => {
|
||||
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
|
||||
let mut result = Vec::new();
|
||||
|
||||
let mut iter = hashes_table.iter()?;
|
||||
while let Some(entry) = iter.next() {
|
||||
let entry = entry?;
|
||||
let (hash_key, field) = entry.0.value();
|
||||
if hash_key == key {
|
||||
let decrypted = self.decrypt_if_needed(entry.1.value())?;
|
||||
let value = String::from_utf8(decrypted)?;
|
||||
result.push((field.to_string(), value));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
|
||||
None => Ok(Vec::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn hdel(&self, key: &str, fields: Vec<String>) -> Result<i64, DBError> {
|
||||
let write_txn = self.db.begin_write()?;
|
||||
let mut deleted = 0i64;
|
||||
|
||||
// First check if key exists and is a hash
|
||||
let key_type = {
|
||||
let types_table = write_txn.open_table(TYPES_TABLE)?;
|
||||
let access_guard = types_table.get(key)?;
|
||||
access_guard.map(|v| v.value().to_string())
|
||||
};
|
||||
|
||||
match key_type.as_deref() {
|
||||
Some("hash") => {
|
||||
let mut hashes_table = write_txn.open_table(HASHES_TABLE)?;
|
||||
|
||||
for field in fields {
|
||||
if hashes_table.remove((key, field.as_str()))?.is_some() {
|
||||
deleted += 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Check if hash is now empty and remove type if so
|
||||
let mut has_fields = false;
|
||||
let mut iter = hashes_table.iter()?;
|
||||
while let Some(entry) = iter.next() {
|
||||
let entry = entry?;
|
||||
let (hash_key, _) = entry.0.value();
|
||||
if hash_key == key {
|
||||
has_fields = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
drop(iter);
|
||||
|
||||
if !has_fields {
|
||||
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
|
||||
types_table.remove(key)?;
|
||||
}
|
||||
}
|
||||
Some(_) => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
|
||||
None => {} // Key does not exist, nothing to delete, return 0 deleted
|
||||
}
|
||||
|
||||
write_txn.commit()?;
|
||||
Ok(deleted)
|
||||
}
|
||||
|
||||
pub fn hexists(&self, key: &str, field: &str) -> Result<bool, DBError> {
|
||||
let read_txn = self.db.begin_read()?;
|
||||
let types_table = read_txn.open_table(TYPES_TABLE)?;
|
||||
let key_type = {
|
||||
let access_guard = types_table.get(key)?;
|
||||
access_guard.map(|v| v.value().to_string())
|
||||
};
|
||||
|
||||
match key_type.as_deref() {
|
||||
Some("hash") => {
|
||||
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
|
||||
Ok(hashes_table.get((key, field))?.is_some())
|
||||
}
|
||||
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
|
||||
None => Ok(false),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn hkeys(&self, key: &str) -> Result<Vec<String>, DBError> {
|
||||
let read_txn = self.db.begin_read()?;
|
||||
let types_table = read_txn.open_table(TYPES_TABLE)?;
|
||||
let key_type = {
|
||||
let access_guard = types_table.get(key)?;
|
||||
access_guard.map(|v| v.value().to_string())
|
||||
};
|
||||
|
||||
match key_type.as_deref() {
|
||||
Some("hash") => {
|
||||
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
|
||||
let mut result = Vec::new();
|
||||
|
||||
let mut iter = hashes_table.iter()?;
|
||||
while let Some(entry) = iter.next() {
|
||||
let entry = entry?;
|
||||
let (hash_key, field) = entry.0.value();
|
||||
if hash_key == key {
|
||||
result.push(field.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
|
||||
None => Ok(Vec::new()),
|
||||
}
|
||||
}
|
||||
|
||||
// ✅ ENCRYPTION APPLIED: All values are decrypted after retrieval
|
||||
pub fn hvals(&self, key: &str) -> Result<Vec<String>, DBError> {
|
||||
let read_txn = self.db.begin_read()?;
|
||||
let types_table = read_txn.open_table(TYPES_TABLE)?;
|
||||
let key_type = {
|
||||
let access_guard = types_table.get(key)?;
|
||||
access_guard.map(|v| v.value().to_string())
|
||||
};
|
||||
|
||||
match key_type.as_deref() {
|
||||
Some("hash") => {
|
||||
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
|
||||
let mut result = Vec::new();
|
||||
|
||||
let mut iter = hashes_table.iter()?;
|
||||
while let Some(entry) = iter.next() {
|
||||
let entry = entry?;
|
||||
let (hash_key, _) = entry.0.value();
|
||||
if hash_key == key {
|
||||
let decrypted = self.decrypt_if_needed(entry.1.value())?;
|
||||
let value = String::from_utf8(decrypted)?;
|
||||
result.push(value);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
|
||||
None => Ok(Vec::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn hlen(&self, key: &str) -> Result<i64, DBError> {
|
||||
let read_txn = self.db.begin_read()?;
|
||||
let types_table = read_txn.open_table(TYPES_TABLE)?;
|
||||
let key_type = {
|
||||
let access_guard = types_table.get(key)?;
|
||||
access_guard.map(|v| v.value().to_string())
|
||||
};
|
||||
|
||||
match key_type.as_deref() {
|
||||
Some("hash") => {
|
||||
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
|
||||
let mut count = 0i64;
|
||||
|
||||
let mut iter = hashes_table.iter()?;
|
||||
while let Some(entry) = iter.next() {
|
||||
let entry = entry?;
|
||||
let (hash_key, _) = entry.0.value();
|
||||
if hash_key == key {
|
||||
count += 1;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(count)
|
||||
}
|
||||
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
|
||||
None => Ok(0),
|
||||
}
|
||||
}
|
||||
|
||||
// ✅ ENCRYPTION APPLIED: Values are decrypted after retrieval
|
||||
pub fn hmget(&self, key: &str, fields: Vec<String>) -> Result<Vec<Option<String>>, DBError> {
|
||||
let read_txn = self.db.begin_read()?;
|
||||
let types_table = read_txn.open_table(TYPES_TABLE)?;
|
||||
let key_type = {
|
||||
let access_guard = types_table.get(key)?;
|
||||
access_guard.map(|v| v.value().to_string())
|
||||
};
|
||||
|
||||
match key_type.as_deref() {
|
||||
Some("hash") => {
|
||||
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
|
||||
let mut result = Vec::new();
|
||||
|
||||
for field in fields {
|
||||
match hashes_table.get((key, field.as_str()))? {
|
||||
Some(data) => {
|
||||
let decrypted = self.decrypt_if_needed(data.value())?;
|
||||
let value = String::from_utf8(decrypted)?;
|
||||
result.push(Some(value));
|
||||
}
|
||||
None => result.push(None),
|
||||
}
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
|
||||
None => Ok(fields.into_iter().map(|_| None).collect()),
|
||||
}
|
||||
}
|
||||
|
||||
// ✅ ENCRYPTION APPLIED: Value is encrypted before storage
|
||||
pub fn hsetnx(&self, key: &str, field: &str, value: &str) -> Result<bool, DBError> {
|
||||
let write_txn = self.db.begin_write()?;
|
||||
let mut result = false;
|
||||
|
||||
{
|
||||
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
|
||||
let mut hashes_table = write_txn.open_table(HASHES_TABLE)?;
|
||||
|
||||
let key_type = {
|
||||
let access_guard = types_table.get(key)?;
|
||||
access_guard.map(|v| v.value().to_string())
|
||||
};
|
||||
|
||||
match key_type.as_deref() {
|
||||
Some("hash") | None => { // Proceed if hash or new key
|
||||
// Check if field already exists
|
||||
if hashes_table.get((key, field))?.is_none() {
|
||||
// Set the type to hash (only if new key or existing hash)
|
||||
types_table.insert(key, "hash")?;
|
||||
|
||||
// Encrypt the value before storing
|
||||
let encrypted = self.encrypt_if_needed(value.as_bytes())?;
|
||||
hashes_table.insert((key, field), encrypted.as_slice())?;
|
||||
result = true;
|
||||
}
|
||||
}
|
||||
Some(_) => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
write_txn.commit()?;
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
// ✅ ENCRYPTION APPLIED: Values are decrypted after retrieval
|
||||
pub fn hscan(&self, key: &str, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<(String, String)>), DBError> {
|
||||
let read_txn = self.db.begin_read()?;
|
||||
let types_table = read_txn.open_table(TYPES_TABLE)?;
|
||||
let key_type = {
|
||||
let access_guard = types_table.get(key)?;
|
||||
access_guard.map(|v| v.value().to_string())
|
||||
};
|
||||
|
||||
match key_type.as_deref() {
|
||||
Some("hash") => {
|
||||
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
|
||||
let mut result = Vec::new();
|
||||
let mut current_cursor = 0u64;
|
||||
let limit = count.unwrap_or(10) as usize;
|
||||
|
||||
let mut iter = hashes_table.iter()?;
|
||||
while let Some(entry) = iter.next() {
|
||||
let entry = entry?;
|
||||
let (hash_key, field) = entry.0.value();
|
||||
|
||||
if hash_key == key {
|
||||
if current_cursor >= cursor {
|
||||
let field_str = field.to_string();
|
||||
|
||||
// Apply pattern matching if specified
|
||||
let matches = if let Some(pat) = pattern {
|
||||
super::storage_extra::glob_match(pat, &field_str)
|
||||
} else {
|
||||
true
|
||||
};
|
||||
|
||||
if matches {
|
||||
let decrypted = self.decrypt_if_needed(entry.1.value())?;
|
||||
let value = String::from_utf8(decrypted)?;
|
||||
result.push((field_str, value));
|
||||
|
||||
if result.len() >= limit {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
current_cursor += 1;
|
||||
}
|
||||
}
|
||||
|
||||
let next_cursor = if result.len() < limit { 0 } else { current_cursor };
|
||||
Ok((next_cursor, result))
|
||||
}
|
||||
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
|
||||
None => Ok((0, Vec::new())),
|
||||
}
|
||||
}
|
||||
}
|
403
src/storage/storage_lists.rs
Normal file
403
src/storage/storage_lists.rs
Normal file
@@ -0,0 +1,403 @@
|
||||
use redb::{ReadableTable};
|
||||
use crate::error::DBError;
|
||||
use super::*;
|
||||
|
||||
impl Storage {
|
||||
// ✅ ENCRYPTION APPLIED: Elements are encrypted before storage
|
||||
pub fn lpush(&self, key: &str, elements: Vec<String>) -> Result<i64, DBError> {
|
||||
let write_txn = self.db.begin_write()?;
|
||||
let mut _length = 0i64;
|
||||
|
||||
{
|
||||
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
|
||||
let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
|
||||
|
||||
// Set the type to list
|
||||
types_table.insert(key, "list")?;
|
||||
|
||||
// Get current list or create empty one
|
||||
let mut list: Vec<String> = match lists_table.get(key)? {
|
||||
Some(data) => {
|
||||
let decrypted = self.decrypt_if_needed(data.value())?;
|
||||
serde_json::from_slice(&decrypted)?
|
||||
}
|
||||
None => Vec::new(),
|
||||
};
|
||||
|
||||
// Add elements to the front (left)
|
||||
for element in elements.into_iter() {
|
||||
list.insert(0, element);
|
||||
}
|
||||
|
||||
_length = list.len() as i64;
|
||||
|
||||
// Encrypt and store the updated list
|
||||
let serialized = serde_json::to_vec(&list)?;
|
||||
let encrypted = self.encrypt_if_needed(&serialized)?;
|
||||
lists_table.insert(key, encrypted.as_slice())?;
|
||||
}
|
||||
|
||||
write_txn.commit()?;
|
||||
Ok(_length)
|
||||
}
|
||||
|
||||
// ✅ ENCRYPTION APPLIED: Elements are encrypted before storage
|
||||
pub fn rpush(&self, key: &str, elements: Vec<String>) -> Result<i64, DBError> {
|
||||
let write_txn = self.db.begin_write()?;
|
||||
let mut _length = 0i64;
|
||||
|
||||
{
|
||||
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
|
||||
let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
|
||||
|
||||
// Set the type to list
|
||||
types_table.insert(key, "list")?;
|
||||
|
||||
// Get current list or create empty one
|
||||
let mut list: Vec<String> = match lists_table.get(key)? {
|
||||
Some(data) => {
|
||||
let decrypted = self.decrypt_if_needed(data.value())?;
|
||||
serde_json::from_slice(&decrypted)?
|
||||
}
|
||||
None => Vec::new(),
|
||||
};
|
||||
|
||||
// Add elements to the end (right)
|
||||
list.extend(elements);
|
||||
_length = list.len() as i64;
|
||||
|
||||
// Encrypt and store the updated list
|
||||
let serialized = serde_json::to_vec(&list)?;
|
||||
let encrypted = self.encrypt_if_needed(&serialized)?;
|
||||
lists_table.insert(key, encrypted.as_slice())?;
|
||||
}
|
||||
|
||||
write_txn.commit()?;
|
||||
Ok(_length)
|
||||
}
|
||||
|
||||
// ✅ ENCRYPTION APPLIED: Elements are decrypted after retrieval and encrypted before storage
|
||||
pub fn lpop(&self, key: &str, count: u64) -> Result<Vec<String>, DBError> {
|
||||
let write_txn = self.db.begin_write()?;
|
||||
let mut result = Vec::new();
|
||||
|
||||
// First check if key exists and is a list, and get the data
|
||||
let list_data = {
|
||||
let types_table = write_txn.open_table(TYPES_TABLE)?;
|
||||
let lists_table = write_txn.open_table(LISTS_TABLE)?;
|
||||
|
||||
let result = match types_table.get(key)? {
|
||||
Some(type_val) if type_val.value() == "list" => {
|
||||
if let Some(data) = lists_table.get(key)? {
|
||||
let decrypted = self.decrypt_if_needed(data.value())?;
|
||||
let list: Vec<String> = serde_json::from_slice(&decrypted)?;
|
||||
Some(list)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
result
|
||||
};
|
||||
|
||||
if let Some(mut list) = list_data {
|
||||
let pop_count = std::cmp::min(count as usize, list.len());
|
||||
for _ in 0..pop_count {
|
||||
if !list.is_empty() {
|
||||
result.push(list.remove(0));
|
||||
}
|
||||
}
|
||||
|
||||
let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
|
||||
if list.is_empty() {
|
||||
// Remove the key if list is empty
|
||||
lists_table.remove(key)?;
|
||||
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
|
||||
types_table.remove(key)?;
|
||||
} else {
|
||||
// Encrypt and store the updated list
|
||||
let serialized = serde_json::to_vec(&list)?;
|
||||
let encrypted = self.encrypt_if_needed(&serialized)?;
|
||||
lists_table.insert(key, encrypted.as_slice())?;
|
||||
}
|
||||
}
|
||||
|
||||
write_txn.commit()?;
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
// ✅ ENCRYPTION APPLIED: Elements are decrypted after retrieval and encrypted before storage
|
||||
pub fn rpop(&self, key: &str, count: u64) -> Result<Vec<String>, DBError> {
|
||||
let write_txn = self.db.begin_write()?;
|
||||
let mut result = Vec::new();
|
||||
|
||||
// First check if key exists and is a list, and get the data
|
||||
let list_data = {
|
||||
let types_table = write_txn.open_table(TYPES_TABLE)?;
|
||||
let lists_table = write_txn.open_table(LISTS_TABLE)?;
|
||||
|
||||
let result = match types_table.get(key)? {
|
||||
Some(type_val) if type_val.value() == "list" => {
|
||||
if let Some(data) = lists_table.get(key)? {
|
||||
let decrypted = self.decrypt_if_needed(data.value())?;
|
||||
let list: Vec<String> = serde_json::from_slice(&decrypted)?;
|
||||
Some(list)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
result
|
||||
};
|
||||
|
||||
if let Some(mut list) = list_data {
|
||||
let pop_count = std::cmp::min(count as usize, list.len());
|
||||
for _ in 0..pop_count {
|
||||
if !list.is_empty() {
|
||||
result.push(list.pop().unwrap());
|
||||
}
|
||||
}
|
||||
|
||||
let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
|
||||
if list.is_empty() {
|
||||
// Remove the key if list is empty
|
||||
lists_table.remove(key)?;
|
||||
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
|
||||
types_table.remove(key)?;
|
||||
} else {
|
||||
// Encrypt and store the updated list
|
||||
let serialized = serde_json::to_vec(&list)?;
|
||||
let encrypted = self.encrypt_if_needed(&serialized)?;
|
||||
lists_table.insert(key, encrypted.as_slice())?;
|
||||
}
|
||||
}
|
||||
|
||||
write_txn.commit()?;
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
pub fn llen(&self, key: &str) -> Result<i64, DBError> {
|
||||
let read_txn = self.db.begin_read()?;
|
||||
let types_table = read_txn.open_table(TYPES_TABLE)?;
|
||||
|
||||
match types_table.get(key)? {
|
||||
Some(type_val) if type_val.value() == "list" => {
|
||||
let lists_table = read_txn.open_table(LISTS_TABLE)?;
|
||||
match lists_table.get(key)? {
|
||||
Some(data) => {
|
||||
let decrypted = self.decrypt_if_needed(data.value())?;
|
||||
let list: Vec<String> = serde_json::from_slice(&decrypted)?;
|
||||
Ok(list.len() as i64)
|
||||
}
|
||||
None => Ok(0),
|
||||
}
|
||||
}
|
||||
_ => Ok(0),
|
||||
}
|
||||
}
|
||||
|
||||
// ✅ ENCRYPTION APPLIED: Element is decrypted after retrieval
|
||||
pub fn lindex(&self, key: &str, index: i64) -> Result<Option<String>, DBError> {
|
||||
let read_txn = self.db.begin_read()?;
|
||||
let types_table = read_txn.open_table(TYPES_TABLE)?;
|
||||
|
||||
match types_table.get(key)? {
|
||||
Some(type_val) if type_val.value() == "list" => {
|
||||
let lists_table = read_txn.open_table(LISTS_TABLE)?;
|
||||
match lists_table.get(key)? {
|
||||
Some(data) => {
|
||||
let decrypted = self.decrypt_if_needed(data.value())?;
|
||||
let list: Vec<String> = serde_json::from_slice(&decrypted)?;
|
||||
|
||||
let actual_index = if index < 0 {
|
||||
list.len() as i64 + index
|
||||
} else {
|
||||
index
|
||||
};
|
||||
|
||||
if actual_index >= 0 && (actual_index as usize) < list.len() {
|
||||
Ok(Some(list[actual_index as usize].clone()))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
None => Ok(None),
|
||||
}
|
||||
}
|
||||
_ => Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
// ✅ ENCRYPTION APPLIED: Elements are decrypted after retrieval
|
||||
pub fn lrange(&self, key: &str, start: i64, stop: i64) -> Result<Vec<String>, DBError> {
|
||||
let read_txn = self.db.begin_read()?;
|
||||
let types_table = read_txn.open_table(TYPES_TABLE)?;
|
||||
|
||||
match types_table.get(key)? {
|
||||
Some(type_val) if type_val.value() == "list" => {
|
||||
let lists_table = read_txn.open_table(LISTS_TABLE)?;
|
||||
match lists_table.get(key)? {
|
||||
Some(data) => {
|
||||
let decrypted = self.decrypt_if_needed(data.value())?;
|
||||
let list: Vec<String> = serde_json::from_slice(&decrypted)?;
|
||||
|
||||
if list.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let len = list.len() as i64;
|
||||
let start_idx = if start < 0 { std::cmp::max(0, len + start) } else { std::cmp::min(start, len) };
|
||||
let stop_idx = if stop < 0 { std::cmp::max(-1, len + stop) } else { std::cmp::min(stop, len - 1) };
|
||||
|
||||
if start_idx > stop_idx || start_idx >= len {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let start_usize = start_idx as usize;
|
||||
let stop_usize = (stop_idx + 1) as usize;
|
||||
|
||||
Ok(list[start_usize..std::cmp::min(stop_usize, list.len())].to_vec())
|
||||
}
|
||||
None => Ok(Vec::new()),
|
||||
}
|
||||
}
|
||||
_ => Ok(Vec::new()),
|
||||
}
|
||||
}
|
||||
|
||||
// ✅ ENCRYPTION APPLIED: Elements are decrypted after retrieval and encrypted before storage
|
||||
pub fn ltrim(&self, key: &str, start: i64, stop: i64) -> Result<(), DBError> {
|
||||
let write_txn = self.db.begin_write()?;
|
||||
|
||||
// First check if key exists and is a list, and get the data
|
||||
let list_data = {
|
||||
let types_table = write_txn.open_table(TYPES_TABLE)?;
|
||||
let lists_table = write_txn.open_table(LISTS_TABLE)?;
|
||||
|
||||
let result = match types_table.get(key)? {
|
||||
Some(type_val) if type_val.value() == "list" => {
|
||||
if let Some(data) = lists_table.get(key)? {
|
||||
let decrypted = self.decrypt_if_needed(data.value())?;
|
||||
let list: Vec<String> = serde_json::from_slice(&decrypted)?;
|
||||
Some(list)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
result
|
||||
};
|
||||
|
||||
if let Some(list) = list_data {
|
||||
if list.is_empty() {
|
||||
write_txn.commit()?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let len = list.len() as i64;
|
||||
let start_idx = if start < 0 { std::cmp::max(0, len + start) } else { std::cmp::min(start, len) };
|
||||
let stop_idx = if stop < 0 { std::cmp::max(-1, len + stop) } else { std::cmp::min(stop, len - 1) };
|
||||
|
||||
let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
|
||||
if start_idx > stop_idx || start_idx >= len {
|
||||
// Remove the entire list
|
||||
lists_table.remove(key)?;
|
||||
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
|
||||
types_table.remove(key)?;
|
||||
} else {
|
||||
let start_usize = start_idx as usize;
|
||||
let stop_usize = (stop_idx + 1) as usize;
|
||||
let trimmed = list[start_usize..std::cmp::min(stop_usize, list.len())].to_vec();
|
||||
|
||||
if trimmed.is_empty() {
|
||||
lists_table.remove(key)?;
|
||||
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
|
||||
types_table.remove(key)?;
|
||||
} else {
|
||||
// Encrypt and store the trimmed list
|
||||
let serialized = serde_json::to_vec(&trimmed)?;
|
||||
let encrypted = self.encrypt_if_needed(&serialized)?;
|
||||
lists_table.insert(key, encrypted.as_slice())?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
write_txn.commit()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ✅ ENCRYPTION APPLIED: Elements are decrypted after retrieval and encrypted before storage
|
||||
pub fn lrem(&self, key: &str, count: i64, element: &str) -> Result<i64, DBError> {
|
||||
let write_txn = self.db.begin_write()?;
|
||||
let mut removed = 0i64;
|
||||
|
||||
// First check if key exists and is a list, and get the data
|
||||
let list_data = {
|
||||
let types_table = write_txn.open_table(TYPES_TABLE)?;
|
||||
let lists_table = write_txn.open_table(LISTS_TABLE)?;
|
||||
|
||||
let result = match types_table.get(key)? {
|
||||
Some(type_val) if type_val.value() == "list" => {
|
||||
if let Some(data) = lists_table.get(key)? {
|
||||
let decrypted = self.decrypt_if_needed(data.value())?;
|
||||
let list: Vec<String> = serde_json::from_slice(&decrypted)?;
|
||||
Some(list)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
result
|
||||
};
|
||||
|
||||
if let Some(mut list) = list_data {
|
||||
if count == 0 {
|
||||
// Remove all occurrences
|
||||
let original_len = list.len();
|
||||
list.retain(|x| x != element);
|
||||
removed = (original_len - list.len()) as i64;
|
||||
} else if count > 0 {
|
||||
// Remove first count occurrences
|
||||
let mut to_remove = count as usize;
|
||||
list.retain(|x| {
|
||||
if x == element && to_remove > 0 {
|
||||
to_remove -= 1;
|
||||
removed += 1;
|
||||
false
|
||||
} else {
|
||||
true
|
||||
}
|
||||
});
|
||||
} else {
|
||||
// Remove last |count| occurrences
|
||||
let mut to_remove = (-count) as usize;
|
||||
for i in (0..list.len()).rev() {
|
||||
if list[i] == element && to_remove > 0 {
|
||||
list.remove(i);
|
||||
to_remove -= 1;
|
||||
removed += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
|
||||
if list.is_empty() {
|
||||
lists_table.remove(key)?;
|
||||
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
|
||||
types_table.remove(key)?;
|
||||
} else {
|
||||
// Encrypt and store the updated list
|
||||
let serialized = serde_json::to_vec(&list)?;
|
||||
let encrypted = self.encrypt_if_needed(&serialized)?;
|
||||
lists_table.insert(key, encrypted.as_slice())?;
|
||||
}
|
||||
}
|
||||
|
||||
write_txn.commit()?;
|
||||
Ok(removed)
|
||||
}
|
||||
}
|
845
src/storage_sled/mod.rs
Normal file
845
src/storage_sled/mod.rs
Normal file
@@ -0,0 +1,845 @@
|
||||
// src/storage_sled/mod.rs
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use std::collections::HashMap;
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use crate::error::DBError;
|
||||
use crate::storage_trait::StorageBackend;
|
||||
use crate::crypto::CryptoFactory;
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
enum ValueType {
|
||||
String(String),
|
||||
Hash(HashMap<String, String>),
|
||||
List(Vec<String>),
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
struct StorageValue {
|
||||
value: ValueType,
|
||||
expires_at: Option<u128>, // milliseconds since epoch
|
||||
}
|
||||
|
||||
pub struct SledStorage {
|
||||
db: sled::Db,
|
||||
types: sled::Tree,
|
||||
crypto: Option<CryptoFactory>,
|
||||
}
|
||||
|
||||
impl SledStorage {
|
||||
pub fn new(path: impl AsRef<Path>, should_encrypt: bool, master_key: Option<&str>) -> Result<Self, DBError> {
|
||||
let db = sled::open(path).map_err(|e| DBError(format!("Failed to open sled: {}", e)))?;
|
||||
let types = db.open_tree("types").map_err(|e| DBError(format!("Failed to open types tree: {}", e)))?;
|
||||
|
||||
// Check if database was previously encrypted
|
||||
let encrypted_tree = db.open_tree("encrypted").map_err(|e| DBError(e.to_string()))?;
|
||||
let was_encrypted = encrypted_tree.get("encrypted")
|
||||
.map_err(|e| DBError(e.to_string()))?
|
||||
.map(|v| v[0] == 1)
|
||||
.unwrap_or(false);
|
||||
|
||||
let crypto = if should_encrypt || was_encrypted {
|
||||
if let Some(key) = master_key {
|
||||
Some(CryptoFactory::new(key.as_bytes()))
|
||||
} else {
|
||||
return Err(DBError("Encryption requested but no master key provided".to_string()));
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Mark database as encrypted if enabling encryption
|
||||
if should_encrypt && !was_encrypted {
|
||||
encrypted_tree.insert("encrypted", &[1u8])
|
||||
.map_err(|e| DBError(e.to_string()))?;
|
||||
encrypted_tree.flush().map_err(|e| DBError(e.to_string()))?;
|
||||
}
|
||||
|
||||
Ok(SledStorage { db, types, crypto })
|
||||
}
|
||||
|
||||
fn now_millis() -> u128 {
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_millis()
|
||||
}
|
||||
|
||||
fn encrypt_if_needed(&self, data: &[u8]) -> Result<Vec<u8>, DBError> {
|
||||
if let Some(crypto) = &self.crypto {
|
||||
Ok(crypto.encrypt(data))
|
||||
} else {
|
||||
Ok(data.to_vec())
|
||||
}
|
||||
}
|
||||
|
||||
fn decrypt_if_needed(&self, data: &[u8]) -> Result<Vec<u8>, DBError> {
|
||||
if let Some(crypto) = &self.crypto {
|
||||
Ok(crypto.decrypt(data)?)
|
||||
} else {
|
||||
Ok(data.to_vec())
|
||||
}
|
||||
}
|
||||
|
||||
fn get_storage_value(&self, key: &str) -> Result<Option<StorageValue>, DBError> {
|
||||
match self.db.get(key).map_err(|e| DBError(e.to_string()))? {
|
||||
Some(encrypted_data) => {
|
||||
let decrypted = self.decrypt_if_needed(&encrypted_data)?;
|
||||
let storage_val: StorageValue = bincode::deserialize(&decrypted)
|
||||
.map_err(|e| DBError(format!("Deserialization error: {}", e)))?;
|
||||
|
||||
// Check expiration
|
||||
if let Some(expires_at) = storage_val.expires_at {
|
||||
if Self::now_millis() > expires_at {
|
||||
// Expired, remove it
|
||||
self.db.remove(key).map_err(|e| DBError(e.to_string()))?;
|
||||
self.types.remove(key).map_err(|e| DBError(e.to_string()))?;
|
||||
return Ok(None);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Some(storage_val))
|
||||
}
|
||||
None => Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
fn set_storage_value(&self, key: &str, storage_val: StorageValue) -> Result<(), DBError> {
|
||||
let data = bincode::serialize(&storage_val)
|
||||
.map_err(|e| DBError(format!("Serialization error: {}", e)))?;
|
||||
let encrypted = self.encrypt_if_needed(&data)?;
|
||||
self.db.insert(key, encrypted).map_err(|e| DBError(e.to_string()))?;
|
||||
|
||||
// Store type info (unencrypted for efficiency)
|
||||
let type_str = match &storage_val.value {
|
||||
ValueType::String(_) => "string",
|
||||
ValueType::Hash(_) => "hash",
|
||||
ValueType::List(_) => "list",
|
||||
};
|
||||
self.types.insert(key, type_str.as_bytes()).map_err(|e| DBError(e.to_string()))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn glob_match(pattern: &str, text: &str) -> bool {
|
||||
if pattern == "*" {
|
||||
return true;
|
||||
}
|
||||
|
||||
let pattern_chars: Vec<char> = pattern.chars().collect();
|
||||
let text_chars: Vec<char> = text.chars().collect();
|
||||
|
||||
fn match_recursive(pattern: &[char], text: &[char], pi: usize, ti: usize) -> bool {
|
||||
if pi >= pattern.len() {
|
||||
return ti >= text.len();
|
||||
}
|
||||
|
||||
if ti >= text.len() {
|
||||
return pattern[pi..].iter().all(|&c| c == '*');
|
||||
}
|
||||
|
||||
match pattern[pi] {
|
||||
'*' => {
|
||||
for i in ti..=text.len() {
|
||||
if match_recursive(pattern, text, pi + 1, i) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
'?' => match_recursive(pattern, text, pi + 1, ti + 1),
|
||||
c => {
|
||||
if text[ti] == c {
|
||||
match_recursive(pattern, text, pi + 1, ti + 1)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
match_recursive(&pattern_chars, &text_chars, 0, 0)
|
||||
}
|
||||
}
|
||||
|
||||
impl StorageBackend for SledStorage {
|
||||
fn get(&self, key: &str) -> Result<Option<String>, DBError> {
|
||||
match self.get_storage_value(key)? {
|
||||
Some(storage_val) => match storage_val.value {
|
||||
ValueType::String(s) => Ok(Some(s)),
|
||||
_ => Ok(None)
|
||||
}
|
||||
None => Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
fn set(&self, key: String, value: String) -> Result<(), DBError> {
|
||||
let storage_val = StorageValue {
|
||||
value: ValueType::String(value),
|
||||
expires_at: None,
|
||||
};
|
||||
self.set_storage_value(&key, storage_val)?;
|
||||
self.db.flush().map_err(|e| DBError(e.to_string()))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn setx(&self, key: String, value: String, expire_ms: u128) -> Result<(), DBError> {
|
||||
let storage_val = StorageValue {
|
||||
value: ValueType::String(value),
|
||||
expires_at: Some(Self::now_millis() + expire_ms),
|
||||
};
|
||||
self.set_storage_value(&key, storage_val)?;
|
||||
self.db.flush().map_err(|e| DBError(e.to_string()))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn del(&self, key: String) -> Result<(), DBError> {
|
||||
self.db.remove(&key).map_err(|e| DBError(e.to_string()))?;
|
||||
self.types.remove(&key).map_err(|e| DBError(e.to_string()))?;
|
||||
self.db.flush().map_err(|e| DBError(e.to_string()))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn exists(&self, key: &str) -> Result<bool, DBError> {
|
||||
// Check with expiration
|
||||
Ok(self.get_storage_value(key)?.is_some())
|
||||
}
|
||||
|
||||
fn keys(&self, pattern: &str) -> Result<Vec<String>, DBError> {
|
||||
let mut keys = Vec::new();
|
||||
for item in self.types.iter() {
|
||||
let (key_bytes, _) = item.map_err(|e| DBError(e.to_string()))?;
|
||||
let key = String::from_utf8_lossy(&key_bytes).to_string();
|
||||
|
||||
// Check if key is expired
|
||||
if self.get_storage_value(&key)?.is_some() {
|
||||
if Self::glob_match(pattern, &key) {
|
||||
keys.push(key);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(keys)
|
||||
}
|
||||
|
||||
fn scan(&self, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<(String, String)>), DBError> {
|
||||
let mut result = Vec::new();
|
||||
let mut current_cursor = 0u64;
|
||||
let limit = count.unwrap_or(10) as usize;
|
||||
|
||||
for item in self.types.iter() {
|
||||
if current_cursor >= cursor {
|
||||
let (key_bytes, type_bytes) = item.map_err(|e| DBError(e.to_string()))?;
|
||||
let key = String::from_utf8_lossy(&key_bytes).to_string();
|
||||
|
||||
// Check pattern match
|
||||
let matches = if let Some(pat) = pattern {
|
||||
Self::glob_match(pat, &key)
|
||||
} else {
|
||||
true
|
||||
};
|
||||
|
||||
if matches {
|
||||
// Check if key is expired and get value
|
||||
if let Some(storage_val) = self.get_storage_value(&key)? {
|
||||
let value = match storage_val.value {
|
||||
ValueType::String(s) => s,
|
||||
_ => String::from_utf8_lossy(&type_bytes).to_string(),
|
||||
};
|
||||
result.push((key, value));
|
||||
|
||||
if result.len() >= limit {
|
||||
current_cursor += 1;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
current_cursor += 1;
|
||||
}
|
||||
|
||||
let next_cursor = if result.len() < limit { 0 } else { current_cursor };
|
||||
Ok((next_cursor, result))
|
||||
}
|
||||
|
||||
fn dbsize(&self) -> Result<i64, DBError> {
|
||||
let mut count = 0i64;
|
||||
for item in self.types.iter() {
|
||||
let (key_bytes, _) = item.map_err(|e| DBError(e.to_string()))?;
|
||||
let key = String::from_utf8_lossy(&key_bytes).to_string();
|
||||
if self.get_storage_value(&key)?.is_some() {
|
||||
count += 1;
|
||||
}
|
||||
}
|
||||
Ok(count)
|
||||
}
|
||||
|
||||
fn flushdb(&self) -> Result<(), DBError> {
|
||||
self.db.clear().map_err(|e| DBError(e.to_string()))?;
|
||||
self.types.clear().map_err(|e| DBError(e.to_string()))?;
|
||||
self.db.flush().map_err(|e| DBError(e.to_string()))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_key_type(&self, key: &str) -> Result<Option<String>, DBError> {
|
||||
// First check if key exists (handles expiration)
|
||||
if self.get_storage_value(key)?.is_some() {
|
||||
match self.types.get(key).map_err(|e| DBError(e.to_string()))? {
|
||||
Some(data) => Ok(Some(String::from_utf8_lossy(&data).to_string())),
|
||||
None => Ok(None)
|
||||
}
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
// Hash operations
|
||||
fn hset(&self, key: &str, pairs: Vec<(String, String)>) -> Result<i64, DBError> {
|
||||
let mut storage_val = self.get_storage_value(key)?.unwrap_or(StorageValue {
|
||||
value: ValueType::Hash(HashMap::new()),
|
||||
expires_at: None,
|
||||
});
|
||||
|
||||
let hash = match &mut storage_val.value {
|
||||
ValueType::Hash(h) => h,
|
||||
_ => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
|
||||
};
|
||||
|
||||
let mut new_fields = 0i64;
|
||||
for (field, value) in pairs {
|
||||
if !hash.contains_key(&field) {
|
||||
new_fields += 1;
|
||||
}
|
||||
hash.insert(field, value);
|
||||
}
|
||||
|
||||
self.set_storage_value(key, storage_val)?;
|
||||
self.db.flush().map_err(|e| DBError(e.to_string()))?;
|
||||
Ok(new_fields)
|
||||
}
|
||||
|
||||
fn hget(&self, key: &str, field: &str) -> Result<Option<String>, DBError> {
|
||||
match self.get_storage_value(key)? {
|
||||
Some(storage_val) => match storage_val.value {
|
||||
ValueType::Hash(h) => Ok(h.get(field).cloned()),
|
||||
_ => Ok(None)
|
||||
}
|
||||
None => Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
fn hgetall(&self, key: &str) -> Result<Vec<(String, String)>, DBError> {
|
||||
match self.get_storage_value(key)? {
|
||||
Some(storage_val) => match storage_val.value {
|
||||
ValueType::Hash(h) => Ok(h.into_iter().collect()),
|
||||
_ => Ok(Vec::new())
|
||||
}
|
||||
None => Ok(Vec::new())
|
||||
}
|
||||
}
|
||||
|
||||
fn hscan(&self, key: &str, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<(String, String)>), DBError> {
|
||||
match self.get_storage_value(key)? {
|
||||
Some(storage_val) => match storage_val.value {
|
||||
ValueType::Hash(h) => {
|
||||
let mut result = Vec::new();
|
||||
let mut current_cursor = 0u64;
|
||||
let limit = count.unwrap_or(10) as usize;
|
||||
|
||||
for (field, value) in h.iter() {
|
||||
if current_cursor >= cursor {
|
||||
let matches = if let Some(pat) = pattern {
|
||||
Self::glob_match(pat, field)
|
||||
} else {
|
||||
true
|
||||
};
|
||||
|
||||
if matches {
|
||||
result.push((field.clone(), value.clone()));
|
||||
if result.len() >= limit {
|
||||
current_cursor += 1;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
current_cursor += 1;
|
||||
}
|
||||
|
||||
let next_cursor = if result.len() < limit { 0 } else { current_cursor };
|
||||
Ok((next_cursor, result))
|
||||
}
|
||||
_ => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string()))
|
||||
}
|
||||
None => Ok((0, Vec::new()))
|
||||
}
|
||||
}
|
||||
|
||||
fn hdel(&self, key: &str, fields: Vec<String>) -> Result<i64, DBError> {
|
||||
let mut storage_val = match self.get_storage_value(key)? {
|
||||
Some(sv) => sv,
|
||||
None => return Ok(0)
|
||||
};
|
||||
|
||||
let hash = match &mut storage_val.value {
|
||||
ValueType::Hash(h) => h,
|
||||
_ => return Ok(0)
|
||||
};
|
||||
|
||||
let mut deleted = 0i64;
|
||||
for field in fields {
|
||||
if hash.remove(&field).is_some() {
|
||||
deleted += 1;
|
||||
}
|
||||
}
|
||||
|
||||
if hash.is_empty() {
|
||||
self.del(key.to_string())?;
|
||||
} else {
|
||||
self.set_storage_value(key, storage_val)?;
|
||||
self.db.flush().map_err(|e| DBError(e.to_string()))?;
|
||||
}
|
||||
|
||||
Ok(deleted)
|
||||
}
|
||||
|
||||
fn hexists(&self, key: &str, field: &str) -> Result<bool, DBError> {
|
||||
match self.get_storage_value(key)? {
|
||||
Some(storage_val) => match storage_val.value {
|
||||
ValueType::Hash(h) => Ok(h.contains_key(field)),
|
||||
_ => Ok(false)
|
||||
}
|
||||
None => Ok(false)
|
||||
}
|
||||
}
|
||||
|
||||
fn hkeys(&self, key: &str) -> Result<Vec<String>, DBError> {
|
||||
match self.get_storage_value(key)? {
|
||||
Some(storage_val) => match storage_val.value {
|
||||
ValueType::Hash(h) => Ok(h.keys().cloned().collect()),
|
||||
_ => Ok(Vec::new())
|
||||
}
|
||||
None => Ok(Vec::new())
|
||||
}
|
||||
}
|
||||
|
||||
fn hvals(&self, key: &str) -> Result<Vec<String>, DBError> {
|
||||
match self.get_storage_value(key)? {
|
||||
Some(storage_val) => match storage_val.value {
|
||||
ValueType::Hash(h) => Ok(h.values().cloned().collect()),
|
||||
_ => Ok(Vec::new())
|
||||
}
|
||||
None => Ok(Vec::new())
|
||||
}
|
||||
}
|
||||
|
||||
fn hlen(&self, key: &str) -> Result<i64, DBError> {
|
||||
match self.get_storage_value(key)? {
|
||||
Some(storage_val) => match storage_val.value {
|
||||
ValueType::Hash(h) => Ok(h.len() as i64),
|
||||
_ => Ok(0)
|
||||
}
|
||||
None => Ok(0)
|
||||
}
|
||||
}
|
||||
|
||||
fn hmget(&self, key: &str, fields: Vec<String>) -> Result<Vec<Option<String>>, DBError> {
|
||||
match self.get_storage_value(key)? {
|
||||
Some(storage_val) => match storage_val.value {
|
||||
ValueType::Hash(h) => {
|
||||
Ok(fields.into_iter().map(|f| h.get(&f).cloned()).collect())
|
||||
}
|
||||
_ => Ok(fields.into_iter().map(|_| None).collect())
|
||||
}
|
||||
None => Ok(fields.into_iter().map(|_| None).collect())
|
||||
}
|
||||
}
|
||||
|
||||
fn hsetnx(&self, key: &str, field: &str, value: &str) -> Result<bool, DBError> {
|
||||
let mut storage_val = self.get_storage_value(key)?.unwrap_or(StorageValue {
|
||||
value: ValueType::Hash(HashMap::new()),
|
||||
expires_at: None,
|
||||
});
|
||||
|
||||
let hash = match &mut storage_val.value {
|
||||
ValueType::Hash(h) => h,
|
||||
_ => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
|
||||
};
|
||||
|
||||
if hash.contains_key(field) {
|
||||
Ok(false)
|
||||
} else {
|
||||
hash.insert(field.to_string(), value.to_string());
|
||||
self.set_storage_value(key, storage_val)?;
|
||||
self.db.flush().map_err(|e| DBError(e.to_string()))?;
|
||||
Ok(true)
|
||||
}
|
||||
}
|
||||
|
||||
// List operations
|
||||
fn lpush(&self, key: &str, elements: Vec<String>) -> Result<i64, DBError> {
|
||||
let mut storage_val = self.get_storage_value(key)?.unwrap_or(StorageValue {
|
||||
value: ValueType::List(Vec::new()),
|
||||
expires_at: None,
|
||||
});
|
||||
|
||||
let list = match &mut storage_val.value {
|
||||
ValueType::List(l) => l,
|
||||
_ => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
|
||||
};
|
||||
|
||||
for element in elements.into_iter().rev() {
|
||||
list.insert(0, element);
|
||||
}
|
||||
|
||||
let len = list.len() as i64;
|
||||
self.set_storage_value(key, storage_val)?;
|
||||
self.db.flush().map_err(|e| DBError(e.to_string()))?;
|
||||
Ok(len)
|
||||
}
|
||||
|
||||
fn rpush(&self, key: &str, elements: Vec<String>) -> Result<i64, DBError> {
|
||||
let mut storage_val = self.get_storage_value(key)?.unwrap_or(StorageValue {
|
||||
value: ValueType::List(Vec::new()),
|
||||
expires_at: None,
|
||||
});
|
||||
|
||||
let list = match &mut storage_val.value {
|
||||
ValueType::List(l) => l,
|
||||
_ => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
|
||||
};
|
||||
|
||||
list.extend(elements);
|
||||
let len = list.len() as i64;
|
||||
self.set_storage_value(key, storage_val)?;
|
||||
self.db.flush().map_err(|e| DBError(e.to_string()))?;
|
||||
Ok(len)
|
||||
}
|
||||
|
||||
fn lpop(&self, key: &str, count: u64) -> Result<Vec<String>, DBError> {
|
||||
let mut storage_val = match self.get_storage_value(key)? {
|
||||
Some(sv) => sv,
|
||||
None => return Ok(Vec::new())
|
||||
};
|
||||
|
||||
let list = match &mut storage_val.value {
|
||||
ValueType::List(l) => l,
|
||||
_ => return Ok(Vec::new())
|
||||
};
|
||||
|
||||
let mut result = Vec::new();
|
||||
for _ in 0..count.min(list.len() as u64) {
|
||||
if let Some(elem) = list.first() {
|
||||
result.push(elem.clone());
|
||||
list.remove(0);
|
||||
}
|
||||
}
|
||||
|
||||
if list.is_empty() {
|
||||
self.del(key.to_string())?;
|
||||
} else {
|
||||
self.set_storage_value(key, storage_val)?;
|
||||
self.db.flush().map_err(|e| DBError(e.to_string()))?;
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn rpop(&self, key: &str, count: u64) -> Result<Vec<String>, DBError> {
|
||||
let mut storage_val = match self.get_storage_value(key)? {
|
||||
Some(sv) => sv,
|
||||
None => return Ok(Vec::new())
|
||||
};
|
||||
|
||||
let list = match &mut storage_val.value {
|
||||
ValueType::List(l) => l,
|
||||
_ => return Ok(Vec::new())
|
||||
};
|
||||
|
||||
let mut result = Vec::new();
|
||||
for _ in 0..count.min(list.len() as u64) {
|
||||
if let Some(elem) = list.pop() {
|
||||
result.push(elem);
|
||||
}
|
||||
}
|
||||
|
||||
if list.is_empty() {
|
||||
self.del(key.to_string())?;
|
||||
} else {
|
||||
self.set_storage_value(key, storage_val)?;
|
||||
self.db.flush().map_err(|e| DBError(e.to_string()))?;
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn llen(&self, key: &str) -> Result<i64, DBError> {
|
||||
match self.get_storage_value(key)? {
|
||||
Some(storage_val) => match storage_val.value {
|
||||
ValueType::List(l) => Ok(l.len() as i64),
|
||||
_ => Ok(0)
|
||||
}
|
||||
None => Ok(0)
|
||||
}
|
||||
}
|
||||
|
||||
fn lindex(&self, key: &str, index: i64) -> Result<Option<String>, DBError> {
|
||||
match self.get_storage_value(key)? {
|
||||
Some(storage_val) => match storage_val.value {
|
||||
ValueType::List(list) => {
|
||||
let actual_index = if index < 0 {
|
||||
list.len() as i64 + index
|
||||
} else {
|
||||
index
|
||||
};
|
||||
|
||||
if actual_index >= 0 && (actual_index as usize) < list.len() {
|
||||
Ok(Some(list[actual_index as usize].clone()))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
_ => Ok(None)
|
||||
}
|
||||
None => Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
fn lrange(&self, key: &str, start: i64, stop: i64) -> Result<Vec<String>, DBError> {
|
||||
match self.get_storage_value(key)? {
|
||||
Some(storage_val) => match storage_val.value {
|
||||
ValueType::List(list) => {
|
||||
if list.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let len = list.len() as i64;
|
||||
let start_idx = if start < 0 {
|
||||
std::cmp::max(0, len + start)
|
||||
} else {
|
||||
std::cmp::min(start, len)
|
||||
};
|
||||
let stop_idx = if stop < 0 {
|
||||
std::cmp::max(-1, len + stop)
|
||||
} else {
|
||||
std::cmp::min(stop, len - 1)
|
||||
};
|
||||
|
||||
if start_idx > stop_idx || start_idx >= len {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let start_usize = start_idx as usize;
|
||||
let stop_usize = (stop_idx + 1) as usize;
|
||||
|
||||
Ok(list[start_usize..std::cmp::min(stop_usize, list.len())].to_vec())
|
||||
}
|
||||
_ => Ok(Vec::new())
|
||||
}
|
||||
None => Ok(Vec::new())
|
||||
}
|
||||
}
|
||||
|
||||
fn ltrim(&self, key: &str, start: i64, stop: i64) -> Result<(), DBError> {
|
||||
let mut storage_val = match self.get_storage_value(key)? {
|
||||
Some(sv) => sv,
|
||||
None => return Ok(())
|
||||
};
|
||||
|
||||
let list = match &mut storage_val.value {
|
||||
ValueType::List(l) => l,
|
||||
_ => return Ok(())
|
||||
};
|
||||
|
||||
if list.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let len = list.len() as i64;
|
||||
let start_idx = if start < 0 {
|
||||
std::cmp::max(0, len + start)
|
||||
} else {
|
||||
std::cmp::min(start, len)
|
||||
};
|
||||
let stop_idx = if stop < 0 {
|
||||
std::cmp::max(-1, len + stop)
|
||||
} else {
|
||||
std::cmp::min(stop, len - 1)
|
||||
};
|
||||
|
||||
if start_idx > stop_idx || start_idx >= len {
|
||||
self.del(key.to_string())?;
|
||||
} else {
|
||||
let start_usize = start_idx as usize;
|
||||
let stop_usize = (stop_idx + 1) as usize;
|
||||
*list = list[start_usize..std::cmp::min(stop_usize, list.len())].to_vec();
|
||||
|
||||
if list.is_empty() {
|
||||
self.del(key.to_string())?;
|
||||
} else {
|
||||
self.set_storage_value(key, storage_val)?;
|
||||
self.db.flush().map_err(|e| DBError(e.to_string()))?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn lrem(&self, key: &str, count: i64, element: &str) -> Result<i64, DBError> {
|
||||
let mut storage_val = match self.get_storage_value(key)? {
|
||||
Some(sv) => sv,
|
||||
None => return Ok(0)
|
||||
};
|
||||
|
||||
let list = match &mut storage_val.value {
|
||||
ValueType::List(l) => l,
|
||||
_ => return Ok(0)
|
||||
};
|
||||
|
||||
let mut removed = 0i64;
|
||||
|
||||
if count == 0 {
|
||||
// Remove all occurrences
|
||||
let original_len = list.len();
|
||||
list.retain(|x| x != element);
|
||||
removed = (original_len - list.len()) as i64;
|
||||
} else if count > 0 {
|
||||
// Remove first count occurrences
|
||||
let mut to_remove = count as usize;
|
||||
list.retain(|x| {
|
||||
if x == element && to_remove > 0 {
|
||||
to_remove -= 1;
|
||||
removed += 1;
|
||||
false
|
||||
} else {
|
||||
true
|
||||
}
|
||||
});
|
||||
} else {
|
||||
// Remove last |count| occurrences
|
||||
let mut to_remove = (-count) as usize;
|
||||
for i in (0..list.len()).rev() {
|
||||
if list[i] == element && to_remove > 0 {
|
||||
list.remove(i);
|
||||
to_remove -= 1;
|
||||
removed += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if list.is_empty() {
|
||||
self.del(key.to_string())?;
|
||||
} else {
|
||||
self.set_storage_value(key, storage_val)?;
|
||||
self.db.flush().map_err(|e| DBError(e.to_string()))?;
|
||||
}
|
||||
|
||||
Ok(removed)
|
||||
}
|
||||
|
||||
// Expiration
|
||||
fn ttl(&self, key: &str) -> Result<i64, DBError> {
|
||||
match self.get_storage_value(key)? {
|
||||
Some(storage_val) => {
|
||||
if let Some(expires_at) = storage_val.expires_at {
|
||||
let now = Self::now_millis();
|
||||
if now >= expires_at {
|
||||
Ok(-2) // Key has expired
|
||||
} else {
|
||||
Ok(((expires_at - now) / 1000) as i64) // TTL in seconds
|
||||
}
|
||||
} else {
|
||||
Ok(-1) // Key exists but has no expiration
|
||||
}
|
||||
}
|
||||
None => Ok(-2) // Key does not exist
|
||||
}
|
||||
}
|
||||
|
||||
fn expire_seconds(&self, key: &str, secs: u64) -> Result<bool, DBError> {
|
||||
let mut storage_val = match self.get_storage_value(key)? {
|
||||
Some(sv) => sv,
|
||||
None => return Ok(false)
|
||||
};
|
||||
|
||||
storage_val.expires_at = Some(Self::now_millis() + (secs as u128) * 1000);
|
||||
self.set_storage_value(key, storage_val)?;
|
||||
self.db.flush().map_err(|e| DBError(e.to_string()))?;
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
fn pexpire_millis(&self, key: &str, ms: u128) -> Result<bool, DBError> {
|
||||
let mut storage_val = match self.get_storage_value(key)? {
|
||||
Some(sv) => sv,
|
||||
None => return Ok(false)
|
||||
};
|
||||
|
||||
storage_val.expires_at = Some(Self::now_millis() + ms);
|
||||
self.set_storage_value(key, storage_val)?;
|
||||
self.db.flush().map_err(|e| DBError(e.to_string()))?;
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
fn persist(&self, key: &str) -> Result<bool, DBError> {
|
||||
let mut storage_val = match self.get_storage_value(key)? {
|
||||
Some(sv) => sv,
|
||||
None => return Ok(false)
|
||||
};
|
||||
|
||||
if storage_val.expires_at.is_some() {
|
||||
storage_val.expires_at = None;
|
||||
self.set_storage_value(key, storage_val)?;
|
||||
self.db.flush().map_err(|e| DBError(e.to_string()))?;
|
||||
Ok(true)
|
||||
} else {
|
||||
Ok(false)
|
||||
}
|
||||
}
|
||||
|
||||
fn expire_at_seconds(&self, key: &str, ts_secs: i64) -> Result<bool, DBError> {
|
||||
let mut storage_val = match self.get_storage_value(key)? {
|
||||
Some(sv) => sv,
|
||||
None => return Ok(false)
|
||||
};
|
||||
|
||||
let expires_at_ms: u128 = if ts_secs <= 0 { 0 } else { (ts_secs as u128) * 1000 };
|
||||
storage_val.expires_at = Some(expires_at_ms);
|
||||
self.set_storage_value(key, storage_val)?;
|
||||
self.db.flush().map_err(|e| DBError(e.to_string()))?;
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
fn pexpire_at_millis(&self, key: &str, ts_ms: i64) -> Result<bool, DBError> {
|
||||
let mut storage_val = match self.get_storage_value(key)? {
|
||||
Some(sv) => sv,
|
||||
None => return Ok(false)
|
||||
};
|
||||
|
||||
let expires_at_ms: u128 = if ts_ms <= 0 { 0 } else { ts_ms as u128 };
|
||||
storage_val.expires_at = Some(expires_at_ms);
|
||||
self.set_storage_value(key, storage_val)?;
|
||||
self.db.flush().map_err(|e| DBError(e.to_string()))?;
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
fn is_encrypted(&self) -> bool {
|
||||
self.crypto.is_some()
|
||||
}
|
||||
|
||||
fn info(&self) -> Result<Vec<(String, String)>, DBError> {
|
||||
let dbsize = self.dbsize()?;
|
||||
Ok(vec![
|
||||
("db_size".to_string(), dbsize.to_string()),
|
||||
("is_encrypted".to_string(), self.is_encrypted().to_string()),
|
||||
])
|
||||
}
|
||||
|
||||
fn clone_arc(&self) -> Arc<dyn StorageBackend> {
|
||||
// Note: This is a simplified clone - in production you might want to
|
||||
// handle this differently as sled::Db is already Arc internally
|
||||
Arc::new(SledStorage {
|
||||
db: self.db.clone(),
|
||||
types: self.types.clone(),
|
||||
crypto: self.crypto.clone(),
|
||||
})
|
||||
}
|
||||
}
|
58
src/storage_trait.rs
Normal file
58
src/storage_trait.rs
Normal file
@@ -0,0 +1,58 @@
|
||||
// src/storage_trait.rs
|
||||
use crate::error::DBError;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub trait StorageBackend: Send + Sync {
|
||||
// Basic key operations
|
||||
fn get(&self, key: &str) -> Result<Option<String>, DBError>;
|
||||
fn set(&self, key: String, value: String) -> Result<(), DBError>;
|
||||
fn setx(&self, key: String, value: String, expire_ms: u128) -> Result<(), DBError>;
|
||||
fn del(&self, key: String) -> Result<(), DBError>;
|
||||
fn exists(&self, key: &str) -> Result<bool, DBError>;
|
||||
fn keys(&self, pattern: &str) -> Result<Vec<String>, DBError>;
|
||||
fn dbsize(&self) -> Result<i64, DBError>;
|
||||
fn flushdb(&self) -> Result<(), DBError>;
|
||||
fn get_key_type(&self, key: &str) -> Result<Option<String>, DBError>;
|
||||
|
||||
// Scanning
|
||||
fn scan(&self, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<(String, String)>), DBError>;
|
||||
fn hscan(&self, key: &str, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<(String, String)>), DBError>;
|
||||
|
||||
// Hash operations
|
||||
fn hset(&self, key: &str, pairs: Vec<(String, String)>) -> Result<i64, DBError>;
|
||||
fn hget(&self, key: &str, field: &str) -> Result<Option<String>, DBError>;
|
||||
fn hgetall(&self, key: &str) -> Result<Vec<(String, String)>, DBError>;
|
||||
fn hdel(&self, key: &str, fields: Vec<String>) -> Result<i64, DBError>;
|
||||
fn hexists(&self, key: &str, field: &str) -> Result<bool, DBError>;
|
||||
fn hkeys(&self, key: &str) -> Result<Vec<String>, DBError>;
|
||||
fn hvals(&self, key: &str) -> Result<Vec<String>, DBError>;
|
||||
fn hlen(&self, key: &str) -> Result<i64, DBError>;
|
||||
fn hmget(&self, key: &str, fields: Vec<String>) -> Result<Vec<Option<String>>, DBError>;
|
||||
fn hsetnx(&self, key: &str, field: &str, value: &str) -> Result<bool, DBError>;
|
||||
|
||||
// List operations
|
||||
fn lpush(&self, key: &str, elements: Vec<String>) -> Result<i64, DBError>;
|
||||
fn rpush(&self, key: &str, elements: Vec<String>) -> Result<i64, DBError>;
|
||||
fn lpop(&self, key: &str, count: u64) -> Result<Vec<String>, DBError>;
|
||||
fn rpop(&self, key: &str, count: u64) -> Result<Vec<String>, DBError>;
|
||||
fn llen(&self, key: &str) -> Result<i64, DBError>;
|
||||
fn lindex(&self, key: &str, index: i64) -> Result<Option<String>, DBError>;
|
||||
fn lrange(&self, key: &str, start: i64, stop: i64) -> Result<Vec<String>, DBError>;
|
||||
fn ltrim(&self, key: &str, start: i64, stop: i64) -> Result<(), DBError>;
|
||||
fn lrem(&self, key: &str, count: i64, element: &str) -> Result<i64, DBError>;
|
||||
|
||||
// Expiration
|
||||
fn ttl(&self, key: &str) -> Result<i64, DBError>;
|
||||
fn expire_seconds(&self, key: &str, secs: u64) -> Result<bool, DBError>;
|
||||
fn pexpire_millis(&self, key: &str, ms: u128) -> Result<bool, DBError>;
|
||||
fn persist(&self, key: &str) -> Result<bool, DBError>;
|
||||
fn expire_at_seconds(&self, key: &str, ts_secs: i64) -> Result<bool, DBError>;
|
||||
fn pexpire_at_millis(&self, key: &str, ts_ms: i64) -> Result<bool, DBError>;
|
||||
|
||||
// Metadata
|
||||
fn is_encrypted(&self) -> bool;
|
||||
fn info(&self) -> Result<Vec<(String, String)>, DBError>;
|
||||
|
||||
// Clone to Arc for sharing
|
||||
fn clone_arc(&self) -> Arc<dyn StorageBackend>;
|
||||
}
|
567
src/tantivy_search.rs
Normal file
567
src/tantivy_search.rs
Normal file
@@ -0,0 +1,567 @@
|
||||
use tantivy::{
|
||||
collector::TopDocs,
|
||||
directory::MmapDirectory,
|
||||
query::{QueryParser, BooleanQuery, Query, TermQuery, Occur},
|
||||
schema::{Schema, Field, TextOptions, TextFieldIndexing,
|
||||
STORED, STRING, Value},
|
||||
Index, IndexWriter, IndexReader, ReloadPolicy,
|
||||
Term, DateTime, TantivyDocument,
|
||||
tokenizer::{TokenizerManager},
|
||||
};
|
||||
use std::path::PathBuf;
|
||||
use std::sync::{Arc, RwLock};
|
||||
use std::collections::HashMap;
|
||||
use crate::error::DBError;
|
||||
use serde::{Serialize, Deserialize};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum FieldDef {
|
||||
Text {
|
||||
stored: bool,
|
||||
indexed: bool,
|
||||
tokenized: bool,
|
||||
fast: bool,
|
||||
},
|
||||
Numeric {
|
||||
stored: bool,
|
||||
indexed: bool,
|
||||
fast: bool,
|
||||
precision: NumericType,
|
||||
},
|
||||
Tag {
|
||||
stored: bool,
|
||||
separator: String,
|
||||
case_sensitive: bool,
|
||||
},
|
||||
Geo {
|
||||
stored: bool,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum NumericType {
|
||||
I64,
|
||||
U64,
|
||||
F64,
|
||||
Date,
|
||||
}
|
||||
|
||||
pub struct IndexSchema {
|
||||
schema: Schema,
|
||||
fields: HashMap<String, (Field, FieldDef)>,
|
||||
default_search_fields: Vec<Field>,
|
||||
}
|
||||
|
||||
pub struct TantivySearch {
|
||||
index: Index,
|
||||
writer: Arc<RwLock<IndexWriter>>,
|
||||
reader: IndexReader,
|
||||
index_schema: IndexSchema,
|
||||
name: String,
|
||||
config: IndexConfig,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct IndexConfig {
|
||||
pub language: String,
|
||||
pub stopwords: Vec<String>,
|
||||
pub stemming: bool,
|
||||
pub max_doc_count: Option<usize>,
|
||||
pub default_score: f64,
|
||||
}
|
||||
|
||||
impl Default for IndexConfig {
|
||||
fn default() -> Self {
|
||||
IndexConfig {
|
||||
language: "english".to_string(),
|
||||
stopwords: vec![],
|
||||
stemming: true,
|
||||
max_doc_count: None,
|
||||
default_score: 1.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TantivySearch {
|
||||
pub fn new_with_schema(
|
||||
base_path: PathBuf,
|
||||
name: String,
|
||||
field_definitions: Vec<(String, FieldDef)>,
|
||||
config: Option<IndexConfig>,
|
||||
) -> Result<Self, DBError> {
|
||||
let index_path = base_path.join(&name);
|
||||
std::fs::create_dir_all(&index_path)
|
||||
.map_err(|e| DBError(format!("Failed to create index dir: {}", e)))?;
|
||||
|
||||
// Build schema from field definitions
|
||||
let mut schema_builder = Schema::builder();
|
||||
let mut fields = HashMap::new();
|
||||
let mut default_search_fields = Vec::new();
|
||||
|
||||
// Always add a document ID field
|
||||
let id_field = schema_builder.add_text_field("_id", STRING | STORED);
|
||||
fields.insert("_id".to_string(), (id_field, FieldDef::Text {
|
||||
stored: true,
|
||||
indexed: true,
|
||||
tokenized: false,
|
||||
fast: false,
|
||||
}));
|
||||
|
||||
// Add user-defined fields
|
||||
for (field_name, field_def) in field_definitions {
|
||||
let field = match &field_def {
|
||||
FieldDef::Text { stored, indexed, tokenized, fast: _fast } => {
|
||||
let mut text_options = TextOptions::default();
|
||||
|
||||
if *stored {
|
||||
text_options = text_options.set_stored();
|
||||
}
|
||||
|
||||
if *indexed {
|
||||
let indexing_options = if *tokenized {
|
||||
TextFieldIndexing::default()
|
||||
.set_tokenizer("default")
|
||||
.set_index_option(tantivy::schema::IndexRecordOption::WithFreqsAndPositions)
|
||||
} else {
|
||||
TextFieldIndexing::default()
|
||||
.set_tokenizer("raw")
|
||||
.set_index_option(tantivy::schema::IndexRecordOption::Basic)
|
||||
};
|
||||
text_options = text_options.set_indexing_options(indexing_options);
|
||||
|
||||
let f = schema_builder.add_text_field(&field_name, text_options);
|
||||
if *tokenized {
|
||||
default_search_fields.push(f);
|
||||
}
|
||||
f
|
||||
} else {
|
||||
schema_builder.add_text_field(&field_name, text_options)
|
||||
}
|
||||
}
|
||||
FieldDef::Numeric { stored, indexed, fast, precision } => {
|
||||
match precision {
|
||||
NumericType::I64 => {
|
||||
let mut opts = tantivy::schema::NumericOptions::default();
|
||||
if *stored { opts = opts.set_stored(); }
|
||||
if *indexed { opts = opts.set_indexed(); }
|
||||
if *fast { opts = opts.set_fast(); }
|
||||
schema_builder.add_i64_field(&field_name, opts)
|
||||
}
|
||||
NumericType::U64 => {
|
||||
let mut opts = tantivy::schema::NumericOptions::default();
|
||||
if *stored { opts = opts.set_stored(); }
|
||||
if *indexed { opts = opts.set_indexed(); }
|
||||
if *fast { opts = opts.set_fast(); }
|
||||
schema_builder.add_u64_field(&field_name, opts)
|
||||
}
|
||||
NumericType::F64 => {
|
||||
let mut opts = tantivy::schema::NumericOptions::default();
|
||||
if *stored { opts = opts.set_stored(); }
|
||||
if *indexed { opts = opts.set_indexed(); }
|
||||
if *fast { opts = opts.set_fast(); }
|
||||
schema_builder.add_f64_field(&field_name, opts)
|
||||
}
|
||||
NumericType::Date => {
|
||||
let mut opts = tantivy::schema::DateOptions::default();
|
||||
if *stored { opts = opts.set_stored(); }
|
||||
if *indexed { opts = opts.set_indexed(); }
|
||||
if *fast { opts = opts.set_fast(); }
|
||||
schema_builder.add_date_field(&field_name, opts)
|
||||
}
|
||||
}
|
||||
}
|
||||
FieldDef::Tag { stored, separator: _, case_sensitive: _ } => {
|
||||
let mut text_options = TextOptions::default();
|
||||
if *stored {
|
||||
text_options = text_options.set_stored();
|
||||
}
|
||||
text_options = text_options.set_indexing_options(
|
||||
TextFieldIndexing::default()
|
||||
.set_tokenizer("raw")
|
||||
.set_index_option(tantivy::schema::IndexRecordOption::Basic)
|
||||
);
|
||||
schema_builder.add_text_field(&field_name, text_options)
|
||||
}
|
||||
FieldDef::Geo { stored } => {
|
||||
// For now, store as two f64 fields for lat/lon
|
||||
let mut opts = tantivy::schema::NumericOptions::default();
|
||||
if *stored { opts = opts.set_stored(); }
|
||||
opts = opts.set_indexed().set_fast();
|
||||
|
||||
let lat_field = schema_builder.add_f64_field(&format!("{}_lat", field_name), opts.clone());
|
||||
let lon_field = schema_builder.add_f64_field(&format!("{}_lon", field_name), opts);
|
||||
|
||||
fields.insert(format!("{}_lat", field_name), (lat_field, FieldDef::Numeric {
|
||||
stored: *stored,
|
||||
indexed: true,
|
||||
fast: true,
|
||||
precision: NumericType::F64,
|
||||
}));
|
||||
fields.insert(format!("{}_lon", field_name), (lon_field, FieldDef::Numeric {
|
||||
stored: *stored,
|
||||
indexed: true,
|
||||
fast: true,
|
||||
precision: NumericType::F64,
|
||||
}));
|
||||
continue; // Skip adding the geo field itself
|
||||
}
|
||||
};
|
||||
|
||||
fields.insert(field_name.clone(), (field, field_def));
|
||||
}
|
||||
|
||||
let schema = schema_builder.build();
|
||||
let index_schema = IndexSchema {
|
||||
schema: schema.clone(),
|
||||
fields,
|
||||
default_search_fields,
|
||||
};
|
||||
|
||||
// Create or open index
|
||||
let dir = MmapDirectory::open(&index_path)
|
||||
.map_err(|e| DBError(format!("Failed to open index directory: {}", e)))?;
|
||||
|
||||
let mut index = Index::open_or_create(dir, schema)
|
||||
.map_err(|e| DBError(format!("Failed to create index: {}", e)))?;
|
||||
|
||||
// Configure tokenizers
|
||||
let tokenizer_manager = TokenizerManager::default();
|
||||
index.set_tokenizers(tokenizer_manager);
|
||||
|
||||
let writer = index.writer(1_000_000)
|
||||
.map_err(|e| DBError(format!("Failed to create index writer: {}", e)))?;
|
||||
|
||||
let reader = index
|
||||
.reader_builder()
|
||||
.reload_policy(ReloadPolicy::OnCommitWithDelay)
|
||||
.try_into()
|
||||
.map_err(|e| DBError(format!("Failed to create reader: {}", e)))?;
|
||||
|
||||
let config = config.unwrap_or_default();
|
||||
|
||||
Ok(TantivySearch {
|
||||
index,
|
||||
writer: Arc::new(RwLock::new(writer)),
|
||||
reader,
|
||||
index_schema,
|
||||
name,
|
||||
config,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn add_document_with_fields(
|
||||
&self,
|
||||
doc_id: &str,
|
||||
fields: HashMap<String, String>,
|
||||
) -> Result<(), DBError> {
|
||||
let mut writer = self.writer.write()
|
||||
.map_err(|e| DBError(format!("Failed to acquire writer lock: {}", e)))?;
|
||||
|
||||
// Delete existing document with same ID
|
||||
if let Some((id_field, _)) = self.index_schema.fields.get("_id") {
|
||||
writer.delete_term(Term::from_field_text(*id_field, doc_id));
|
||||
}
|
||||
|
||||
// Create new document
|
||||
let mut doc = tantivy::doc!();
|
||||
|
||||
// Add document ID
|
||||
if let Some((id_field, _)) = self.index_schema.fields.get("_id") {
|
||||
doc.add_text(*id_field, doc_id);
|
||||
}
|
||||
|
||||
// Add other fields based on schema
|
||||
for (field_name, field_value) in fields {
|
||||
if let Some((field, field_def)) = self.index_schema.fields.get(&field_name) {
|
||||
match field_def {
|
||||
FieldDef::Text { .. } => {
|
||||
doc.add_text(*field, &field_value);
|
||||
}
|
||||
FieldDef::Numeric { precision, .. } => {
|
||||
match precision {
|
||||
NumericType::I64 => {
|
||||
if let Ok(v) = field_value.parse::<i64>() {
|
||||
doc.add_i64(*field, v);
|
||||
}
|
||||
}
|
||||
NumericType::U64 => {
|
||||
if let Ok(v) = field_value.parse::<u64>() {
|
||||
doc.add_u64(*field, v);
|
||||
}
|
||||
}
|
||||
NumericType::F64 => {
|
||||
if let Ok(v) = field_value.parse::<f64>() {
|
||||
doc.add_f64(*field, v);
|
||||
}
|
||||
}
|
||||
NumericType::Date => {
|
||||
if let Ok(v) = field_value.parse::<i64>() {
|
||||
doc.add_date(*field, DateTime::from_timestamp_millis(v));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
FieldDef::Tag { separator, case_sensitive, .. } => {
|
||||
let tags = if !case_sensitive {
|
||||
field_value.to_lowercase()
|
||||
} else {
|
||||
field_value.clone()
|
||||
};
|
||||
|
||||
// Store tags as separate terms for efficient filtering
|
||||
for tag in tags.split(separator.as_str()) {
|
||||
doc.add_text(*field, tag.trim());
|
||||
}
|
||||
}
|
||||
FieldDef::Geo { .. } => {
|
||||
// Parse "lat,lon" format
|
||||
let parts: Vec<&str> = field_value.split(',').collect();
|
||||
if parts.len() == 2 {
|
||||
if let (Ok(lat), Ok(lon)) = (parts[0].parse::<f64>(), parts[1].parse::<f64>()) {
|
||||
if let Some((lat_field, _)) = self.index_schema.fields.get(&format!("{}_lat", field_name)) {
|
||||
doc.add_f64(*lat_field, lat);
|
||||
}
|
||||
if let Some((lon_field, _)) = self.index_schema.fields.get(&format!("{}_lon", field_name)) {
|
||||
doc.add_f64(*lon_field, lon);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
writer.add_document(doc).map_err(|e| DBError(format!("Failed to add document: {}", e)))?;
|
||||
|
||||
writer.commit()
|
||||
.map_err(|e| DBError(format!("Failed to commit: {}", e)))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn search_with_options(
|
||||
&self,
|
||||
query_str: &str,
|
||||
options: SearchOptions,
|
||||
) -> Result<SearchResults, DBError> {
|
||||
let searcher = self.reader.searcher();
|
||||
|
||||
// Parse query based on search fields
|
||||
let query: Box<dyn Query> = if self.index_schema.default_search_fields.is_empty() {
|
||||
return Err(DBError("No searchable fields defined in schema".to_string()));
|
||||
} else {
|
||||
let query_parser = QueryParser::for_index(
|
||||
&self.index,
|
||||
self.index_schema.default_search_fields.clone(),
|
||||
);
|
||||
|
||||
Box::new(query_parser.parse_query(query_str)
|
||||
.map_err(|e| DBError(format!("Failed to parse query: {}", e)))?)
|
||||
};
|
||||
|
||||
// Apply filters if any
|
||||
let final_query = if !options.filters.is_empty() {
|
||||
let mut clauses: Vec<(Occur, Box<dyn Query>)> = vec![(Occur::Must, query)];
|
||||
|
||||
// Add filters
|
||||
for filter in options.filters {
|
||||
if let Some((field, _)) = self.index_schema.fields.get(&filter.field) {
|
||||
match filter.filter_type {
|
||||
FilterType::Equals(value) => {
|
||||
let term_query = TermQuery::new(
|
||||
Term::from_field_text(*field, &value),
|
||||
tantivy::schema::IndexRecordOption::Basic,
|
||||
);
|
||||
clauses.push((Occur::Must, Box::new(term_query)));
|
||||
}
|
||||
FilterType::Range { min: _, max: _ } => {
|
||||
// Would need numeric field handling here
|
||||
// Simplified for now
|
||||
}
|
||||
FilterType::InSet(values) => {
|
||||
let mut sub_clauses: Vec<(Occur, Box<dyn Query>)> = vec![];
|
||||
for value in values {
|
||||
let term_query = TermQuery::new(
|
||||
Term::from_field_text(*field, &value),
|
||||
tantivy::schema::IndexRecordOption::Basic,
|
||||
);
|
||||
sub_clauses.push((Occur::Should, Box::new(term_query)));
|
||||
}
|
||||
clauses.push((Occur::Must, Box::new(BooleanQuery::new(sub_clauses))));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Box::new(BooleanQuery::new(clauses))
|
||||
} else {
|
||||
query
|
||||
};
|
||||
|
||||
// Execute search
|
||||
let top_docs = searcher.search(
|
||||
&*final_query,
|
||||
&TopDocs::with_limit(options.limit + options.offset)
|
||||
).map_err(|e| DBError(format!("Search failed: {}", e)))?;
|
||||
|
||||
let total_hits = top_docs.len();
|
||||
let mut documents = Vec::new();
|
||||
|
||||
for (score, doc_address) in top_docs.iter().skip(options.offset).take(options.limit) {
|
||||
let retrieved_doc: TantivyDocument = searcher.doc(*doc_address)
|
||||
.map_err(|e| DBError(format!("Failed to retrieve doc: {}", e)))?;
|
||||
|
||||
let mut doc_fields = HashMap::new();
|
||||
|
||||
// Extract all stored fields
|
||||
for (field_name, (field, field_def)) in &self.index_schema.fields {
|
||||
match field_def {
|
||||
FieldDef::Text { stored, .. } |
|
||||
FieldDef::Tag { stored, .. } => {
|
||||
if *stored {
|
||||
if let Some(value) = retrieved_doc.get_first(*field) {
|
||||
if let Some(text) = value.as_str() {
|
||||
doc_fields.insert(field_name.clone(), text.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
FieldDef::Numeric { stored, precision, .. } => {
|
||||
if *stored {
|
||||
let value_str = match precision {
|
||||
NumericType::I64 => {
|
||||
retrieved_doc.get_first(*field)
|
||||
.and_then(|v| v.as_i64())
|
||||
.map(|v| v.to_string())
|
||||
}
|
||||
NumericType::U64 => {
|
||||
retrieved_doc.get_first(*field)
|
||||
.and_then(|v| v.as_u64())
|
||||
.map(|v| v.to_string())
|
||||
}
|
||||
NumericType::F64 => {
|
||||
retrieved_doc.get_first(*field)
|
||||
.and_then(|v| v.as_f64())
|
||||
.map(|v| v.to_string())
|
||||
}
|
||||
NumericType::Date => {
|
||||
retrieved_doc.get_first(*field)
|
||||
.and_then(|v| v.as_datetime())
|
||||
.map(|v| v.into_timestamp_millis().to_string())
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(v) = value_str {
|
||||
doc_fields.insert(field_name.clone(), v);
|
||||
}
|
||||
}
|
||||
}
|
||||
FieldDef::Geo { stored } => {
|
||||
if *stored {
|
||||
let lat_field = self.index_schema.fields.get(&format!("{}_lat", field_name)).unwrap().0;
|
||||
let lon_field = self.index_schema.fields.get(&format!("{}_lon", field_name)).unwrap().0;
|
||||
|
||||
let lat = retrieved_doc.get_first(lat_field).and_then(|v| v.as_f64());
|
||||
let lon = retrieved_doc.get_first(lon_field).and_then(|v| v.as_f64());
|
||||
|
||||
if let (Some(lat), Some(lon)) = (lat, lon) {
|
||||
doc_fields.insert(field_name.clone(), format!("{},{}", lat, lon));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
documents.push(SearchDocument {
|
||||
fields: doc_fields,
|
||||
score: *score,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(SearchResults {
|
||||
total: total_hits,
|
||||
documents,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn get_info(&self) -> Result<IndexInfo, DBError> {
|
||||
let searcher = self.reader.searcher();
|
||||
let num_docs = searcher.num_docs();
|
||||
|
||||
let fields_info: Vec<FieldInfo> = self.index_schema.fields.iter().map(|(name, (_, def))| {
|
||||
FieldInfo {
|
||||
name: name.clone(),
|
||||
field_type: format!("{:?}", def),
|
||||
}
|
||||
}).collect();
|
||||
|
||||
Ok(IndexInfo {
|
||||
name: self.name.clone(),
|
||||
num_docs,
|
||||
fields: fields_info,
|
||||
config: self.config.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SearchOptions {
|
||||
pub limit: usize,
|
||||
pub offset: usize,
|
||||
pub filters: Vec<Filter>,
|
||||
pub sort_by: Option<String>,
|
||||
pub return_fields: Option<Vec<String>>,
|
||||
pub highlight: bool,
|
||||
}
|
||||
|
||||
impl Default for SearchOptions {
|
||||
fn default() -> Self {
|
||||
SearchOptions {
|
||||
limit: 10,
|
||||
offset: 0,
|
||||
filters: vec![],
|
||||
sort_by: None,
|
||||
return_fields: None,
|
||||
highlight: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Filter {
|
||||
pub field: String,
|
||||
pub filter_type: FilterType,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum FilterType {
|
||||
Equals(String),
|
||||
Range { min: String, max: String },
|
||||
InSet(Vec<String>),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct SearchResults {
|
||||
pub total: usize,
|
||||
pub documents: Vec<SearchDocument>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct SearchDocument {
|
||||
pub fields: HashMap<String, String>,
|
||||
pub score: f32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct IndexInfo {
|
||||
pub name: String,
|
||||
pub num_docs: u64,
|
||||
pub fields: Vec<FieldInfo>,
|
||||
pub config: IndexConfig,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct FieldInfo {
|
||||
pub name: String,
|
||||
pub field_type: String,
|
||||
}
|
Reference in New Issue
Block a user