21 Commits

Author SHA1 Message Date
ee163bb6bf ... 2025-08-16 15:16:15 +02:00
84611dd245 ... 2025-08-16 15:10:55 +02:00
200d0c928d ... 2025-08-16 14:22:56 +02:00
30a09e6d53 ... 2025-08-16 13:58:40 +02:00
542996a0ff ... 2025-08-16 13:33:56 +02:00
63ab39b4b1 ... 2025-08-16 11:22:01 +02:00
ee94d731d7 ... 2025-08-16 11:09:18 +02:00
c7945624bd ... 2025-08-16 10:53:48 +02:00
f8dd304820 it works 2025-08-16 10:41:26 +02:00
5eab3b080c ... 2025-08-16 10:28:28 +02:00
246304b9fa ... 2025-08-16 10:10:24 +02:00
074be114c3 ... 2025-08-16 09:55:34 +02:00
e51af83e45 ... 2025-08-16 09:52:36 +02:00
dbd0635cd9 ... 2025-08-16 09:50:56 +02:00
0000d82799 ... 2025-08-16 09:29:18 +02:00
5502ff4bc5 ... 2025-08-16 09:06:33 +02:00
0511dddd99 ... 2025-08-16 08:50:28 +02:00
bec9b20ec7 ... 2025-08-16 08:41:19 +02:00
ad255a9f51 ... 2025-08-16 08:28:52 +02:00
7bcb673361 ... 2025-08-16 08:25:25 +02:00
0f6e595000 ... 2025-08-16 07:54:55 +02:00
50 changed files with 6342 additions and 2061 deletions

1746
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -1,18 +1,38 @@
[package] [workspace]
name = "redis-rs" resolver = "2"
version = "0.0.1" members = [
authors = ["Pin Fang <fpfangpin@hotmail.com>"] "crates/herodb",
edition = "2021" "crates/libdbstorage",
"crates/libcrypto",
"crates/libcryptoa",
"crates/herocrypto",
"crates/supervisor",
"crates/supervisorrpc",
]
[dependencies] [workspace.dependencies]
anyhow = "1.0.59" # Common
bytes = "1.3.0" anyhow = "1.0"
thiserror = "1.0.32" tokio = { version = "1", features = ["full"] }
tokio = { version = "1.23.0", features = ["full"] }
clap = { version = "4.5.20", features = ["derive"] }
byteorder = "1.4.3"
futures = "0.3"
redb = "2.1.3"
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }
bincode = "1.3.3" serde_json = "1.0"
thiserror = "1.0"
log = "0.4"
bytes = "1.3"
# Crypto - Asymmetric
age = "0.10"
secrecy = "0.8"
ed25519-dalek = "2"
base64 = "0.22"
# Crypto - Symmetric & Utilities
chacha20poly1305 = "0.10"
rand = "0.8"
sha2 = "0.10"
# Database
redb = "2.1"
# CLI
clap = { version = "4.5", features = ["derive"] }

View File

@@ -1,64 +0,0 @@
title: 从 0 到 1 由 Rust 构建 Redis
description: 从 0 到 1 由 Rust 构建 Redis
theme: just-the-docs
url: https://fangpin.github.io/redis-rs
aux_links:
GitHub: https://fangpin.github.io/redis-rs
# logo: "/assets/images/just-the-docs.png"
search_enabled: true
search:
# Split pages into sections that can be searched individually
# Supports 1 - 6, default: 2
heading_level: 2
# Maximum amount of previews per search result
# Default: 3
previews: 3
# Maximum amount of words to display before a matched word in the preview
# Default: 5
preview_words_before: 5
# Maximum amount of words to display after a matched word in the preview
# Default: 10
preview_words_after: 10
# Set the search token separator
# Default: /[\s\-/]+/
# Example: enable support for hyphenated search words
tokenizer_separator: /[\s/]+/
# Display the relative url in search results
# Supports true (default) or false
rel_url: true
# Enable or disable the search button that appears in the bottom right corner of every page
# Supports true or false (default)
button: false
# Heading anchor links appear on hover over h1-h6 tags in page content
# allowing users to deep link to a particular heading on a page.
#
# Supports true (default) or false
heading_anchors: true
# Footer content
# appears at the bottom of every page's main content
# Note: The footer_content option is deprecated and will be removed in a future major release. Please use `_includes/footer_custom.html` for more robust markup / liquid-based content.
footer_content: "Copyright &copy; 2017-2024 Pin Fang"
# Footer last edited timestamp
last_edit_timestamp: true # show or hide edit time - page must have `last_modified_date` defined in the frontmatter
last_edit_time_format: "%b %e %Y at %I:%M %p" # uses ruby's time format: https://ruby-doc.org/stdlib-2.7.0/libdoc/time/rdoc/Time.html
# code
compress_html:
ignore:
envs: all
kramdown:
syntax_highlighter_opts:
block:
line_numbers: true

View File

@@ -0,0 +1,10 @@
[package]
name = "herocrypto"
version = "0.1.0"
edition = "2021"
[dependencies]
redis = { version = "0.24", features = ["tokio-comp"] }
thiserror = { workspace = true }
libcrypto = { path = "../libcrypto" }
libcryptoa = { path = "../libcryptoa" }

View File

@@ -0,0 +1,45 @@
// In crates/herocrypto/src/lib.rs
use redis::{Commands, RedisResult};
use thiserror::Error;
#[derive(Error, Debug)]
pub enum Error {
#[error("Redis connection error: {0}")]
Redis(#[from] redis::RedisError),
#[error("Asymmetric crypto error: {0}")]
Asymmetric(#[from] libcryptoa::AsymmetricCryptoError),
#[error("Key not found in database: {0}")]
KeyNotFound(String),
#[error("Command failed on server: {0}")]
CommandError(String),
}
pub struct HeroCrypto {
// e.g., using a connection manager from redis-rs
client: redis::Client,
}
impl HeroCrypto {
pub fn new(redis_url: &str) -> Result<Self, Error> {
Ok(Self { client: redis::Client::open(redis_url)? })
}
// --- High-level functions to be implemented ---
/// Generates a new keypair and stores it in HeroDB under the given name.
pub async fn generate_keypair(&self, name: &str) -> Result<(), Error> {
let mut con = self.client.get_async_connection().await?;
let (_pub, _priv): (String, String) = redis::cmd("AGE")
.arg("KEYGEN")
.arg(name)
.query_async(&mut con)
.await?;
Ok(())
}
/// Encrypts a message using a key stored in HeroDB.
pub async fn encrypt_by_name(&self, key_name: &str, plaintext: &str) -> Result<String, Error> {
// Implementation will call 'AGE ENCRYPTNAME ...'
unimplemented!()
}
}

31
crates/herodb/Cargo.toml Normal file
View File

@@ -0,0 +1,31 @@
[package]
name = "herodb"
version = "0.1.0"
edition = "2021"
authors = ["Pin Fang <fpfangpin@hotmail.com>"]
[[bin]]
name = "herodb"
path = "src/main.rs"
[dependencies]
# Workspace dependencies
anyhow = { workspace = true }
tokio = { workspace = true }
serde = { workspace = true }
log = { workspace = true }
clap = { workspace = true }
bytes = { workspace = true }
base64 = { workspace = true }
age = { workspace = true }
secrecy = { workspace = true }
ed25519-dalek = { workspace = true }
rand = { workspace = true }
# Local Crate Dependencies
libdbstorage = { path = "../libdbstorage" }
# We will create these libraries in the next steps
libcryptoa = { path = "../libcryptoa" }
[dev-dependencies]
redis = { version = "0.24", features = ["aio", "tokio-comp"] }

326
crates/herodb/age.rs Normal file
View File

@@ -0,0 +1,326 @@
//! age.rs — AGE (rage) helpers + persistent key management for your mini-Redis.
//
// Features:
// - X25519 encryption/decryption (age style)
// - Ed25519 detached signatures + verification
// - Persistent named keys in DB (strings):
// age:key:{name} -> X25519 recipient (public encryption key, "age1...")
// age:privkey:{name} -> X25519 identity (secret encryption key, "AGE-SECRET-KEY-1...")
// age:signpub:{name} -> Ed25519 verify pubkey (public, used to verify signatures)
// age:signpriv:{name} -> Ed25519 signing secret key (private, used to sign)
// - Base64 wrapping for ciphertext/signature binary blobs.
use crate::protocol::Protocol;
use crate::server::Server;
use libdbstorage::DBError;
use libcryptoa::AsymmetricCryptoError;
// ---------- Storage helpers ----------
fn sget(server: &Server, key: &str) -> Result<Option<String>, DBError> {
let st = server.current_storage()?;
st.get(key)
}
fn sset(server: &Server, key: &str, val: &str) -> Result<(), DBError> {
let st = server.current_storage()?;
st.set(key.to_string(), val.to_string())
}
fn enc_pub_key_key(name: &str) -> String { format!("age:key:{name}") }
fn enc_priv_key_key(name: &str) -> String { format!("age:privkey:{name}") }
fn sign_pub_key_key(name: &str) -> String { format!("age:signpub:{name}") }
fn sign_priv_key_key(name: &str) -> String { format!("age:signpriv:{name}") }
// ---------- Command handlers (RESP Protocol) ----------
// Basic (stateless) ones kept for completeness
pub async fn cmd_age_genenc() -> Protocol {
let (recip, ident) = libcryptoa::gen_enc_keypair();
Protocol::Array(vec![Protocol::BulkString(recip), Protocol::BulkString(ident)])
}
pub async fn cmd_age_gensign() -> Protocol {
let (verify, secret) = libcryptoa::gen_sign_keypair();
Protocol::Array(vec![Protocol::BulkString(verify), Protocol::BulkString(secret)])
}
pub async fn cmd_age_encrypt(recipient: &str, message: &str) -> Protocol {
match libcryptoa::encrypt_b64(recipient, message) {
Ok(b64) => Protocol::BulkString(b64),
Err(e) => Protocol::err(&format!("ERR age: {e}")),
}
}
pub async fn cmd_age_decrypt(identity: &str, ct_b64: &str) -> Protocol {
match libcryptoa::decrypt_b64(identity, ct_b64) {
Ok(pt) => Protocol::BulkString(pt),
Err(e) => Protocol::err(&format!("ERR age: {e}")),
}
}
pub async fn cmd_age_sign(secret: &str, message: &str) -> Protocol {
match libcryptoa::sign_b64(secret, message) {
Ok(b64sig) => Protocol::BulkString(b64sig),
Err(e) => Protocol::err(&format!("ERR age: {e}")),
}
}
pub async fn cmd_age_verify(verify_pub: &str, message: &str, sig_b64: &str) -> Protocol {
match libcryptoa::verify_b64(verify_pub, message, sig_b64) {
Ok(true) => Protocol::SimpleString("1".to_string()),
Ok(false) => Protocol::SimpleString("0".to_string()),
Err(e) => Protocol::err(&format!("ERR age: {e}")),
}
}
// ---------- NEW: Persistent, named-key commands ----------
pub async fn cmd_age_keygen(server: &Server, name: &str) -> Protocol {
let (recip, ident) = libcryptoa::gen_enc_keypair();
if let Err(e) = sset(server, &enc_pub_key_key(name), &recip) { return Protocol::err(&e.0); }
if let Err(e) = sset(server, &enc_priv_key_key(name), &ident) { return Protocol::err(&e.0); }
Protocol::Array(vec![Protocol::BulkString(recip), Protocol::BulkString(ident)])
}
pub async fn cmd_age_signkeygen(server: &Server, name: &str) -> Protocol {
let (verify, secret) = libcryptoa::gen_sign_keypair();
if let Err(e) = sset(server, &sign_pub_key_key(name), &verify) { return Protocol::err(&e.0); }
if let Err(e) = sset(server, &sign_priv_key_key(name), &secret) { return Protocol::err(&e.0); }
Protocol::Array(vec![Protocol::BulkString(verify), Protocol::BulkString(secret)])
}
pub async fn cmd_age_encrypt_name(server: &Server, name: &str, message: &str) -> Protocol {
let recip = match sget(server, &enc_pub_key_key(name)) {
Ok(Some(v)) => v,
Ok(None) => return Protocol::err(&format!("ERR age: missing recipient (age:key:{name})")),
Err(e) => return Protocol::err(&e.0),
};
match libcryptoa::encrypt_b64(&recip, message) {
Ok(ct) => Protocol::BulkString(ct),
Err(e) => Protocol::err(&format!("ERR age: {e}")),
}
}
pub async fn cmd_age_decrypt_name(server: &Server, name: &str, ct_b64: &str) -> Protocol {
let ident = match sget(server, &enc_priv_key_key(name)) {
Ok(Some(v)) => v,
Ok(None) => return Protocol::err(&format!("ERR age: missing identity (age:privkey:{name})")),
Err(e) => return Protocol::err(&e.0),
};
match libcryptoa::decrypt_b64(&ident, ct_b64) {
Ok(pt) => Protocol::BulkString(pt),
Err(e) => Protocol::err(&format!("ERR age: {e}")),
}
}
pub async fn cmd_age_sign_name(server: &Server, name: &str, message: &str) -> Protocol {
let sec = match sget(server, &sign_priv_key_key(name)) {
Ok(Some(v)) => v,
Ok(None) => return Protocol::err(&format!("ERR age: missing signing secret (age:signpriv:{name})")),
Err(e) => return Protocol::err(&e.0),
};
match libcryptoa::sign_b64(&sec, message) {
Ok(sig) => Protocol::BulkString(sig),
Err(e) => Protocol::err(&format!("ERR age: {e}")),
}
}
pub async fn cmd_age_verify_name(server: &Server, name: &str, message: &str, sig_b64: &str) -> Protocol {
let pubk = match sget(server, &sign_pub_key_key(name)) {
Ok(Some(v)) => v,
Ok(None) => return Protocol::err(&format!("ERR age: missing verify pubkey (age:signpub:{name})")),
Err(e) => return Protocol::err(&e.0),
};
match libcryptoa::verify_b64(&pubk, message, sig_b64) {
Ok(true) => Protocol::SimpleString("1".to_string()),
Ok(false) => Protocol::SimpleString("0".to_string()),
Err(e) => Protocol::err(&format!("ERR age: {e}")),
}
}
pub async fn cmd_age_list(server: &Server) -> Protocol {
// Returns 4 arrays: ["encpub", <names...>], ["encpriv", ...], ["signpub", ...], ["signpriv", ...]
let st = match server.current_storage() { Ok(s) => s, Err(e) => return Protocol::err(&e.0) };
let pull = |pat: &str, prefix: &str| -> Result<Vec<String>, DBError> {
let keys = st.keys(pat)?;
let mut names: Vec<String> = keys.into_iter()
.filter_map(|k| k.strip_prefix(prefix).map(|x| x.to_string()))
.collect();
names.sort();
Ok(names)
};
let encpub = match pull("age:key:*", "age:key:") { Ok(v) => v, Err(e)=> return Protocol::err(&e.0) };
let encpriv = match pull("age:privkey:*", "age:privkey:") { Ok(v) => v, Err(e)=> return Protocol::err(&e.0) };
let signpub = match pull("age:signpub:*", "age:signpub:") { Ok(v) => v, Err(e)=> return Protocol::err(&e.0) };
let signpriv= match pull("age:signpriv:*", "age:signpriv:") { Ok(v) => v, Err(e)=> return Protocol::err(&e.0) };
let to_arr = |label: &str, v: Vec<String>| {
let mut out = vec![Protocol::BulkString(label.to_string())];
out.push(Protocol::Array(v.into_iter().map(Protocol::BulkString).collect()));
Protocol::Array(out)
};
Protocol::Array(vec![
to_arr("encpub", encpub),
to_arr("encpriv", encpriv),
to_arr("signpub", signpub),
to_arr("signpriv", signpriv),
])
}
// ---------- Storage helpers ----------
fn sget(server: &Server, key: &str) -> Result<Option<String>, AgeWireError> {
let st = server.current_storage().map_err(|e| AgeWireError::Storage(e.0))?;
st.get(key).map_err(|e| AgeWireError::Storage(e.0))
}
fn sset(server: &Server, key: &str, val: &str) -> Result<(), AgeWireError> {
let st = server.current_storage().map_err(|e| AgeWireError::Storage(e.0))?;
st.set(key.to_string(), val.to_string()).map_err(|e| AgeWireError::Storage(e.0))
}
fn enc_pub_key_key(name: &str) -> String { format!("age:key:{name}") }
fn enc_priv_key_key(name: &str) -> String { format!("age:privkey:{name}") }
fn sign_pub_key_key(name: &str) -> String { format!("age:signpub:{name}") }
fn sign_priv_key_key(name: &str) -> String { format!("age:signpriv:{name}") }
// ---------- Command handlers (RESP Protocol) ----------
// Basic (stateless) ones kept for completeness
pub async fn cmd_age_genenc() -> Protocol {
let (recip, ident) = gen_enc_keypair();
Protocol::Array(vec![Protocol::BulkString(recip), Protocol::BulkString(ident)])
}
pub async fn cmd_age_gensign() -> Protocol {
let (verify, secret) = gen_sign_keypair();
Protocol::Array(vec![Protocol::BulkString(verify), Protocol::BulkString(secret)])
}
pub async fn cmd_age_encrypt(recipient: &str, message: &str) -> Protocol {
match encrypt_b64(recipient, message) {
Ok(b64) => Protocol::BulkString(b64),
Err(e) => e.to_protocol(),
}
}
pub async fn cmd_age_decrypt(identity: &str, ct_b64: &str) -> Protocol {
match decrypt_b64(identity, ct_b64) {
Ok(pt) => Protocol::BulkString(pt),
Err(e) => e.to_protocol(),
}
}
pub async fn cmd_age_sign(secret: &str, message: &str) -> Protocol {
match sign_b64(secret, message) {
Ok(b64sig) => Protocol::BulkString(b64sig),
Err(e) => e.to_protocol(),
}
}
pub async fn cmd_age_verify(verify_pub: &str, message: &str, sig_b64: &str) -> Protocol {
match verify_b64(verify_pub, message, sig_b64) {
Ok(true) => Protocol::SimpleString("1".to_string()),
Ok(false) => Protocol::SimpleString("0".to_string()),
Err(e) => e.to_protocol(),
}
}
// ---------- NEW: Persistent, named-key commands ----------
pub async fn cmd_age_keygen(server: &Server, name: &str) -> Protocol {
let (recip, ident) = gen_enc_keypair();
if let Err(e) = sset(server, &enc_pub_key_key(name), &recip) { return e.to_protocol(); }
if let Err(e) = sset(server, &enc_priv_key_key(name), &ident) { return e.to_protocol(); }
Protocol::Array(vec![Protocol::BulkString(recip), Protocol::BulkString(ident)])
}
pub async fn cmd_age_signkeygen(server: &Server, name: &str) -> Protocol {
let (verify, secret) = gen_sign_keypair();
if let Err(e) = sset(server, &sign_pub_key_key(name), &verify) { return e.to_protocol(); }
if let Err(e) = sset(server, &sign_priv_key_key(name), &secret) { return e.to_protocol(); }
Protocol::Array(vec![Protocol::BulkString(verify), Protocol::BulkString(secret)])
}
pub async fn cmd_age_encrypt_name(server: &Server, name: &str, message: &str) -> Protocol {
let recip = match sget(server, &enc_pub_key_key(name)) {
Ok(Some(v)) => v,
Ok(None) => return AgeWireError::NotFound("recipient (age:key:{name})").to_protocol(),
Err(e) => return e.to_protocol(),
};
match encrypt_b64(&recip, message) {
Ok(ct) => Protocol::BulkString(ct),
Err(e) => e.to_protocol(),
}
}
pub async fn cmd_age_decrypt_name(server: &Server, name: &str, ct_b64: &str) -> Protocol {
let ident = match sget(server, &enc_priv_key_key(name)) {
Ok(Some(v)) => v,
Ok(None) => return AgeWireError::NotFound("identity (age:privkey:{name})").to_protocol(),
Err(e) => return e.to_protocol(),
};
match decrypt_b64(&ident, ct_b64) {
Ok(pt) => Protocol::BulkString(pt),
Err(e) => e.to_protocol(),
}
}
pub async fn cmd_age_sign_name(server: &Server, name: &str, message: &str) -> Protocol {
let sec = match sget(server, &sign_priv_key_key(name)) {
Ok(Some(v)) => v,
Ok(None) => return AgeWireError::NotFound("signing secret (age:signpriv:{name})").to_protocol(),
Err(e) => return e.to_protocol(),
};
match sign_b64(&sec, message) {
Ok(sig) => Protocol::BulkString(sig),
Err(e) => e.to_protocol(),
}
}
pub async fn cmd_age_verify_name(server: &Server, name: &str, message: &str, sig_b64: &str) -> Protocol {
let pubk = match sget(server, &sign_pub_key_key(name)) {
Ok(Some(v)) => v,
Ok(None) => return AgeWireError::NotFound("verify pubkey (age:signpub:{name})").to_protocol(),
Err(e) => return e.to_protocol(),
};
match verify_b64(&pubk, message, sig_b64) {
Ok(true) => Protocol::SimpleString("1".to_string()),
Ok(false) => Protocol::SimpleString("0".to_string()),
Err(e) => e.to_protocol(),
}
}
pub async fn cmd_age_list(server: &Server) -> Protocol {
// Returns 4 arrays: ["encpub", <names...>], ["encpriv", ...], ["signpub", ...], ["signpriv", ...]
let st = match server.current_storage() { Ok(s) => s, Err(e) => return Protocol::err(&e.0) };
let pull = |pat: &str, prefix: &str| -> Result<Vec<String>, DBError> {
let keys = st.keys(pat)?;
let mut names: Vec<String> = keys.into_iter()
.filter_map(|k| k.strip_prefix(prefix).map(|x| x.to_string()))
.collect();
names.sort();
Ok(names)
};
let encpub = match pull("age:key:*", "age:key:") { Ok(v) => v, Err(e)=> return Protocol::err(&e.0) };
let encpriv = match pull("age:privkey:*", "age:privkey:") { Ok(v) => v, Err(e)=> return Protocol::err(&e.0) };
let signpub = match pull("age:signpub:*", "age:signpub:") { Ok(v) => v, Err(e)=> return Protocol::err(&e.0) };
let signpriv= match pull("age:signpriv:*", "age:signpriv:") { Ok(v) => v, Err(e)=> return Protocol::err(&e.0) };
let to_arr = |label: &str, v: Vec<String>| {
let mut out = vec![Protocol::BulkString(label.to_string())];
out.push(Protocol::Array(v.into_iter().map(Protocol::BulkString).collect()));
Protocol::Array(out)
};
Protocol::Array(vec![
to_arr("encpub", encpub),
to_arr("encpriv", encpriv),
to_arr("signpub", signpub),
to_arr("signpriv", signpriv),
])
}

973
crates/herodb/cmd.rs Normal file
View File

@@ -0,0 +1,973 @@
use crate::protocol::Protocol;
use crate::server::Server;
use libdbstorage::DBError;
use libcryptoa;
use serde::Serialize;
#[derive(Debug, Clone)]
pub enum Cmd {
Ping,
Echo(String),
Select(u64), // Changed from u16 to u64
Get(String),
Set(String, String),
SetPx(String, String, u128),
SetEx(String, String, u128),
Keys,
ConfigGet(String),
Info(Option<String>),
Del(String),
Type(String),
Incr(String),
Multi,
Exec,
Discard,
// Hash commands
HSet(String, Vec<(String, String)>),
HGet(String, String),
HGetAll(String),
HDel(String, Vec<String>),
HExists(String, String),
HKeys(String),
HVals(String),
HLen(String),
HMGet(String, Vec<String>),
HSetNx(String, String, String),
HScan(String, u64, Option<String>, Option<u64>), // key, cursor, pattern, count
Scan(u64, Option<String>, Option<u64>), // cursor, pattern, count
Ttl(String),
Exists(String),
Quit,
Client(Vec<String>),
ClientSetName(String),
ClientGetName,
// List commands
LPush(String, Vec<String>),
RPush(String, Vec<String>),
LPop(String, Option<u64>),
RPop(String, Option<u64>),
LLen(String),
LRem(String, i64, String),
LTrim(String, i64, i64),
LIndex(String, i64),
LRange(String, i64, i64),
FlushDb,
Unknow(String),
// AGE (rage) commands — stateless
AgeGenEnc,
AgeGenSign,
AgeEncrypt(String, String), // recipient, message
AgeDecrypt(String, String), // identity, ciphertext_b64
AgeSign(String, String), // signing_secret, message
AgeVerify(String, String, String), // verify_pub, message, signature_b64
// NEW: persistent named-key commands
AgeKeygen(String), // name
AgeSignKeygen(String), // name
AgeEncryptName(String, String), // name, message
AgeDecryptName(String, String), // name, ciphertext_b64
AgeSignName(String, String), // name, message
AgeVerifyName(String, String, String), // name, message, signature_b64
AgeList,
}
impl Cmd {
pub fn from(s: &str) -> Result<(Self, Protocol, &str), DBError> {
let (protocol, remaining) = Protocol::from(s)?;
match protocol.clone() {
Protocol::Array(p) => {
let cmd = p.into_iter().map(|x| x.decode()).collect::<Vec<_>>();
if cmd.is_empty() {
return Err(DBError("cmd length is 0".to_string()));
}
Ok((
match cmd[0].to_lowercase().as_str() {
"select" => {
if cmd.len() != 2 {
return Err(DBError("wrong number of arguments for SELECT".to_string()));
}
let idx = cmd[1].parse::<u64>().map_err(|_| DBError("ERR DB index is not an integer".to_string()))?;
Cmd::Select(idx)
}
"echo" => Cmd::Echo(cmd[1].clone()),
"ping" => Cmd::Ping,
"get" => Cmd::Get(cmd[1].clone()),
"set" => {
if cmd.len() == 5 && cmd[3].to_lowercase() == "px" {
Cmd::SetPx(cmd[1].clone(), cmd[2].clone(), cmd[4].parse().unwrap())
} else if cmd.len() == 5 && cmd[3].to_lowercase() == "ex" {
Cmd::SetEx(cmd[1].clone(), cmd[2].clone(), cmd[4].parse().unwrap())
} else if cmd.len() == 3 {
Cmd::Set(cmd[1].clone(), cmd[2].clone())
} else {
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
}
}
"setex" => {
if cmd.len() != 4 {
return Err(DBError(format!("wrong number of arguments for SETEX command")));
}
Cmd::SetEx(cmd[1].clone(), cmd[3].clone(), cmd[2].parse().unwrap())
}
"config" => {
if cmd.len() != 3 || cmd[1].to_lowercase() != "get" {
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
} else {
Cmd::ConfigGet(cmd[2].clone())
}
}
"keys" => {
if cmd.len() != 2 || cmd[1] != "*" {
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
} else {
Cmd::Keys
}
}
"info" => {
let section = if cmd.len() == 2 {
Some(cmd[1].clone())
} else {
None
};
Cmd::Info(section)
}
"del" => {
if cmd.len() != 2 {
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
}
Cmd::Del(cmd[1].clone())
}
"type" => {
if cmd.len() != 2 {
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
}
Cmd::Type(cmd[1].clone())
}
"incr" => {
if cmd.len() != 2 {
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
}
Cmd::Incr(cmd[1].clone())
}
"multi" => {
if cmd.len() != 1 {
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
}
Cmd::Multi
}
"exec" => {
if cmd.len() != 1 {
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
}
Cmd::Exec
}
"discard" => Cmd::Discard,
// Hash commands
"hset" => {
if cmd.len() < 4 || (cmd.len() - 2) % 2 != 0 {
return Err(DBError(format!("wrong number of arguments for HSET command")));
}
let mut pairs = Vec::new();
let mut i = 2;
while i + 1 < cmd.len() {
pairs.push((cmd[i].clone(), cmd[i + 1].clone()));
i += 2;
}
Cmd::HSet(cmd[1].clone(), pairs)
}
"hget" => {
if cmd.len() != 3 {
return Err(DBError(format!("wrong number of arguments for HGET command")));
}
Cmd::HGet(cmd[1].clone(), cmd[2].clone())
}
"hgetall" => {
if cmd.len() != 2 {
return Err(DBError(format!("wrong number of arguments for HGETALL command")));
}
Cmd::HGetAll(cmd[1].clone())
}
"hdel" => {
if cmd.len() < 3 {
return Err(DBError(format!("wrong number of arguments for HDEL command")));
}
Cmd::HDel(cmd[1].clone(), cmd[2..].to_vec())
}
"hexists" => {
if cmd.len() != 3 {
return Err(DBError(format!("wrong number of arguments for HEXISTS command")));
}
Cmd::HExists(cmd[1].clone(), cmd[2].clone())
}
"hkeys" => {
if cmd.len() != 2 {
return Err(DBError(format!("wrong number of arguments for HKEYS command")));
}
Cmd::HKeys(cmd[1].clone())
}
"hvals" => {
if cmd.len() != 2 {
return Err(DBError(format!("wrong number of arguments for HVALS command")));
}
Cmd::HVals(cmd[1].clone())
}
"hlen" => {
if cmd.len() != 2 {
return Err(DBError(format!("wrong number of arguments for HLEN command")));
}
Cmd::HLen(cmd[1].clone())
}
"hmget" => {
if cmd.len() < 3 {
return Err(DBError(format!("wrong number of arguments for HMGET command")));
}
Cmd::HMGet(cmd[1].clone(), cmd[2..].to_vec())
}
"hsetnx" => {
if cmd.len() != 4 {
return Err(DBError(format!("wrong number of arguments for HSETNX command")));
}
Cmd::HSetNx(cmd[1].clone(), cmd[2].clone(), cmd[3].clone())
}
"hscan" => {
if cmd.len() < 3 {
return Err(DBError(format!("wrong number of arguments for HSCAN command")));
}
let key = cmd[1].clone();
let cursor = cmd[2].parse::<u64>().map_err(|_|
DBError("ERR invalid cursor".to_string()))?;
let mut pattern = None;
let mut count = None;
let mut i = 3;
while i < cmd.len() {
match cmd[i].to_lowercase().as_str() {
"match" => {
if i + 1 >= cmd.len() {
return Err(DBError("ERR syntax error".to_string()));
}
pattern = Some(cmd[i + 1].clone());
i += 2;
}
"count" => {
if i + 1 >= cmd.len() {
return Err(DBError("ERR syntax error".to_string()));
}
count = Some(cmd[i + 1].parse::<u64>().map_err(|_|
DBError("ERR value is not an integer or out of range".to_string()))?);
i += 2;
}
_ => {
return Err(DBError(format!("ERR syntax error")));
}
}
}
Cmd::HScan(key, cursor, pattern, count)
}
"scan" => {
if cmd.len() < 2 {
return Err(DBError(format!("wrong number of arguments for SCAN command")));
}
let cursor = cmd[1].parse::<u64>().map_err(|_|
DBError("ERR invalid cursor".to_string()))?;
let mut pattern = None;
let mut count = None;
let mut i = 2;
while i < cmd.len() {
match cmd[i].to_lowercase().as_str() {
"match" => {
if i + 1 >= cmd.len() {
return Err(DBError("ERR syntax error".to_string()));
}
pattern = Some(cmd[i + 1].clone());
i += 2;
}
"count" => {
if i + 1 >= cmd.len() {
return Err(DBError("ERR syntax error".to_string()));
}
count = Some(cmd[i + 1].parse::<u64>().map_err(|_|
DBError("ERR value is not an integer or out of range".to_string()))?);
i += 2;
}
_ => {
return Err(DBError(format!("ERR syntax error")));
}
}
}
Cmd::Scan(cursor, pattern, count)
}
"ttl" => {
if cmd.len() != 2 {
return Err(DBError(format!("wrong number of arguments for TTL command")));
}
Cmd::Ttl(cmd[1].clone())
}
"exists" => {
if cmd.len() != 2 {
return Err(DBError(format!("wrong number of arguments for EXISTS command")));
}
Cmd::Exists(cmd[1].clone())
}
"quit" => {
if cmd.len() != 1 {
return Err(DBError(format!("wrong number of arguments for QUIT command")));
}
Cmd::Quit
}
"client" => {
if cmd.len() > 1 {
match cmd[1].to_lowercase().as_str() {
"setname" => {
if cmd.len() == 3 {
Cmd::ClientSetName(cmd[2].clone())
} else {
return Err(DBError("wrong number of arguments for 'client setname' command".to_string()));
}
}
"getname" => {
if cmd.len() == 2 {
Cmd::ClientGetName
} else {
return Err(DBError("wrong number of arguments for 'client getname' command".to_string()));
}
}
_ => Cmd::Client(cmd[1..].to_vec()),
}
} else {
Cmd::Client(vec![])
}
}
"lpush" => {
if cmd.len() < 3 {
return Err(DBError(format!("wrong number of arguments for LPUSH command")));
}
Cmd::LPush(cmd[1].clone(), cmd[2..].to_vec())
}
"rpush" => {
if cmd.len() < 3 {
return Err(DBError(format!("wrong number of arguments for RPUSH command")));
}
Cmd::RPush(cmd[1].clone(), cmd[2..].to_vec())
}
"lpop" => {
if cmd.len() < 2 || cmd.len() > 3 {
return Err(DBError(format!("wrong number of arguments for LPOP command")));
}
let count = if cmd.len() == 3 {
Some(cmd[2].parse::<u64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?)
} else {
None
};
Cmd::LPop(cmd[1].clone(), count)
}
"rpop" => {
if cmd.len() < 2 || cmd.len() > 3 {
return Err(DBError(format!("wrong number of arguments for RPOP command")));
}
let count = if cmd.len() == 3 {
Some(cmd[2].parse::<u64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?)
} else {
None
};
Cmd::RPop(cmd[1].clone(), count)
}
"llen" => {
if cmd.len() != 2 {
return Err(DBError(format!("wrong number of arguments for LLEN command")));
}
Cmd::LLen(cmd[1].clone())
}
"lrem" => {
if cmd.len() != 4 {
return Err(DBError(format!("wrong number of arguments for LREM command")));
}
let count = cmd[2].parse::<i64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?;
Cmd::LRem(cmd[1].clone(), count, cmd[3].clone())
}
"ltrim" => {
if cmd.len() != 4 {
return Err(DBError(format!("wrong number of arguments for LTRIM command")));
}
let start = cmd[2].parse::<i64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?;
let stop = cmd[3].parse::<i64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?;
Cmd::LTrim(cmd[1].clone(), start, stop)
}
"lindex" => {
if cmd.len() != 3 {
return Err(DBError(format!("wrong number of arguments for LINDEX command")));
}
let index = cmd[2].parse::<i64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?;
Cmd::LIndex(cmd[1].clone(), index)
}
"lrange" => {
if cmd.len() != 4 {
return Err(DBError(format!("wrong number of arguments for LRANGE command")));
}
let start = cmd[2].parse::<i64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?;
let stop = cmd[3].parse::<i64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?;
Cmd::LRange(cmd[1].clone(), start, stop)
}
"flushdb" => {
if cmd.len() != 1 {
return Err(DBError("wrong number of arguments for FLUSHDB command".to_string()));
}
Cmd::FlushDb
}
"age" => {
if cmd.len() < 2 {
return Err(DBError("wrong number of arguments for AGE".to_string()));
}
match cmd[1].to_lowercase().as_str() {
// stateless
"genenc" => { if cmd.len() != 2 { return Err(DBError("AGE GENENC takes no args".to_string())); }
Cmd::AgeGenEnc }
"gensign" => { if cmd.len() != 2 { return Err(DBError("AGE GENSIGN takes no args".to_string())); }
Cmd::AgeGenSign }
"encrypt" => { if cmd.len() != 4 { return Err(DBError("AGE ENCRYPT <recipient> <message>".to_string())); }
Cmd::AgeEncrypt(cmd[2].clone(), cmd[3].clone()) }
"decrypt" => { if cmd.len() != 4 { return Err(DBError("AGE DECRYPT <identity> <ciphertext_b64>".to_string())); }
Cmd::AgeDecrypt(cmd[2].clone(), cmd[3].clone()) }
"sign" => { if cmd.len() != 4 { return Err(DBError("AGE SIGN <signing_secret> <message>".to_string())); }
Cmd::AgeSign(cmd[2].clone(), cmd[3].clone()) }
"verify" => { if cmd.len() != 5 { return Err(DBError("AGE VERIFY <verify_pub> <message> <signature_b64>".to_string())); }
Cmd::AgeVerify(cmd[2].clone(), cmd[3].clone(), cmd[4].clone()) }
// persistent names
"keygen" => { if cmd.len() != 3 { return Err(DBError("AGE KEYGEN <name>".to_string())); }
Cmd::AgeKeygen(cmd[2].clone()) }
"signkeygen" => { if cmd.len() != 3 { return Err(DBError("AGE SIGNKEYGEN <name>".to_string())); }
Cmd::AgeSignKeygen(cmd[2].clone()) }
"encryptname" => { if cmd.len() != 4 { return Err(DBError("AGE ENCRYPTNAME <name> <message>".to_string())); }
Cmd::AgeEncryptName(cmd[2].clone(), cmd[3].clone()) }
"decryptname" => { if cmd.len() != 4 { return Err(DBError("AGE DECRYPTNAME <name> <ciphertext_b64>".to_string())); }
Cmd::AgeDecryptName(cmd[2].clone(), cmd[3].clone()) }
"signname" => { if cmd.len() != 4 { return Err(DBError("AGE SIGNNAME <name> <message>".to_string())); }
Cmd::AgeSignName(cmd[2].clone(), cmd[3].clone()) }
"verifyname" => { if cmd.len() != 5 { return Err(DBError("AGE VERIFYNAME <name> <message> <signature_b64>".to_string())); }
Cmd::AgeVerifyName(cmd[2].clone(), cmd[3].clone(), cmd[4].clone()) }
"list" => { if cmd.len() != 2 { return Err(DBError("AGE LIST".to_string())); }
Cmd::AgeList }
_ => return Err(DBError(format!("unsupported AGE subcommand {:?}", cmd))),
}
}
_ => Cmd::Unknow(cmd[0].clone()),
},
protocol,
remaining
))
}
_ => Err(DBError(format!(
"fail to parse as cmd for {:?}",
protocol
))),
}
}
pub async fn run(self, server: &mut Server) -> Result<Protocol, DBError> {
// Handle queued commands for transactions
if server.queued_cmd.is_some()
&& !matches!(self, Cmd::Exec)
&& !matches!(self, Cmd::Multi)
&& !matches!(self, Cmd::Discard)
{
let protocol = self.clone().to_protocol();
server.queued_cmd.as_mut().unwrap().push((self, protocol));
return Ok(Protocol::SimpleString("QUEUED".to_string()));
}
match self {
Cmd::Select(db) => select_cmd(server, db).await,
Cmd::Ping => Ok(Protocol::SimpleString("PONG".to_string())),
Cmd::Echo(s) => Ok(Protocol::BulkString(s)),
Cmd::Get(k) => get_cmd(server, &k).await,
Cmd::Set(k, v) => set_cmd(server, &k, &v).await,
Cmd::SetPx(k, v, x) => set_px_cmd(server, &k, &v, &x).await,
Cmd::SetEx(k, v, x) => set_ex_cmd(server, &k, &v, &x).await,
Cmd::Del(k) => del_cmd(server, &k).await,
Cmd::ConfigGet(name) => config_get_cmd(&name, server),
Cmd::Keys => keys_cmd(server).await,
Cmd::Info(section) => info_cmd(server, &section).await,
Cmd::Type(k) => type_cmd(server, &k).await,
Cmd::Incr(key) => incr_cmd(server, &key).await,
Cmd::Multi => {
server.queued_cmd = Some(Vec::<(Cmd, Protocol)>::new());
Ok(Protocol::SimpleString("OK".to_string()))
}
Cmd::Exec => exec_cmd(server).await,
Cmd::Discard => {
if server.queued_cmd.is_some() {
server.queued_cmd = None;
Ok(Protocol::SimpleString("OK".to_string()))
} else {
Ok(Protocol::err("ERR DISCARD without MULTI"))
}
}
// Hash commands
Cmd::HSet(key, pairs) => hset_cmd(server, &key, &pairs).await,
Cmd::HGet(key, field) => hget_cmd(server, &key, &field).await,
Cmd::HGetAll(key) => hgetall_cmd(server, &key).await,
Cmd::HDel(key, fields) => hdel_cmd(server, &key, &fields).await,
Cmd::HExists(key, field) => hexists_cmd(server, &key, &field).await,
Cmd::HKeys(key) => hkeys_cmd(server, &key).await,
Cmd::HVals(key) => hvals_cmd(server, &key).await,
Cmd::HLen(key) => hlen_cmd(server, &key).await,
Cmd::HMGet(key, fields) => hmget_cmd(server, &key, &fields).await,
Cmd::HSetNx(key, field, value) => hsetnx_cmd(server, &key, &field, &value).await,
Cmd::HScan(key, cursor, pattern, count) => hscan_cmd(server, &key, &cursor, pattern.as_deref(), &count).await,
Cmd::Scan(cursor, pattern, count) => scan_cmd(server, &cursor, pattern.as_deref(), &count).await,
Cmd::Ttl(key) => ttl_cmd(server, &key).await,
Cmd::Exists(key) => exists_cmd(server, &key).await,
Cmd::Quit => Ok(Protocol::SimpleString("OK".to_string())),
Cmd::Client(_) => Ok(Protocol::SimpleString("OK".to_string())),
Cmd::ClientSetName(name) => client_setname_cmd(server, &name).await,
Cmd::ClientGetName => client_getname_cmd(server).await,
// List commands
Cmd::LPush(key, elements) => lpush_cmd(server, &key, &elements).await,
Cmd::RPush(key, elements) => rpush_cmd(server, &key, &elements).await,
Cmd::LPop(key, count) => lpop_cmd(server, &key, &count).await,
Cmd::RPop(key, count) => rpop_cmd(server, &key, &count).await,
Cmd::LLen(key) => llen_cmd(server, &key).await,
Cmd::LRem(key, count, element) => lrem_cmd(server, &key, count, &element).await,
Cmd::LTrim(key, start, stop) => ltrim_cmd(server, &key, start, stop).await,
Cmd::LIndex(key, index) => lindex_cmd(server, &key, index).await,
Cmd::LRange(key, start, stop) => lrange_cmd(server, &key, start, stop).await,
Cmd::FlushDb => flushdb_cmd(server).await,
// AGE (rage): stateless
Cmd::AgeGenEnc => Ok(libcryptoa::gen_enc_keypair().await),
Cmd::AgeGenSign => Ok(libcryptoa::gen_sign_keypair().await),
Cmd::AgeEncrypt(recipient, message) => Ok(libcryptoa::encrypt_b64(&recipient, &message).await),
Cmd::AgeDecrypt(identity, ct_b64) => Ok(libcryptoa::decrypt_b64(&identity, &ct_b64).await),
Cmd::AgeSign(secret, message) => Ok(libcryptoa::sign_b64(&secret, &message).await),
Cmd::AgeVerify(vpub, msg, sig_b64) => Ok(libcryptoa::verify_b64(&vpub, &msg, &sig_b64).await),
// AGE (rage): persistent named keys
Cmd::AgeKeygen(name) => Ok(crate::age::cmd_age_keygen(server, &name).await),
Cmd::AgeSignKeygen(name) => Ok(crate::age::cmd_age_signkeygen(server, &name).await),
Cmd::AgeEncryptName(name, message) => Ok(crate::age::cmd_age_encrypt_name(server, &name, &message).await),
Cmd::AgeDecryptName(name, ct_b64) => Ok(crate::age::cmd_age_decrypt_name(server, &name, &ct_b64).await),
Cmd::AgeSignName(name, message) => Ok(crate::age::cmd_age_sign_name(server, &name, &message).await),
Cmd::AgeVerifyName(name, message, sig_b64) => Ok(crate::age::cmd_age_verify_name(server, &name, &message, &sig_b64).await),
Cmd::AgeList => Ok(crate::age::cmd_age_list(server).await),
Cmd::Unknow(s) => Ok(Protocol::err(&format!("ERR unknown command `{}`", s))),
}
}
pub fn to_protocol(self) -> Protocol {
match self {
Cmd::Select(db) => Protocol::Array(vec![Protocol::BulkString("select".to_string()), Protocol::BulkString(db.to_string())]),
Cmd::Ping => Protocol::Array(vec![Protocol::BulkString("ping".to_string())]),
Cmd::Echo(s) => Protocol::Array(vec![Protocol::BulkString("echo".to_string()), Protocol::BulkString(s)]),
Cmd::Get(k) => Protocol::Array(vec![Protocol::BulkString("get".to_string()), Protocol::BulkString(k)]),
Cmd::Set(k, v) => Protocol::Array(vec![Protocol::BulkString("set".to_string()), Protocol::BulkString(k), Protocol::BulkString(v)]),
_ => Protocol::SimpleString("...".to_string())
}
}
}
async fn flushdb_cmd(server: &mut Server) -> Result<Protocol, DBError> {
match server.current_storage()?.flushdb() {
Ok(_) => Ok(Protocol::SimpleString("OK".to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn select_cmd(server: &mut Server, db: u64) -> Result<Protocol, DBError> {
// Test if we can access the database (this will create it if needed)
server.selected_db = db;
match server.current_storage() {
Ok(_) => Ok(Protocol::SimpleString("OK".to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn lindex_cmd(server: &Server, key: &str, index: i64) -> Result<Protocol, DBError> {
match server.current_storage()?.lindex(key, index) {
Ok(Some(element)) => Ok(Protocol::BulkString(element)),
Ok(None) => Ok(Protocol::Null),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn lrange_cmd(server: &Server, key: &str, start: i64, stop: i64) -> Result<Protocol, DBError> {
match server.current_storage()?.lrange(key, start, stop) {
Ok(elements) => Ok(Protocol::Array(elements.into_iter().map(Protocol::BulkString).collect())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn ltrim_cmd(server: &Server, key: &str, start: i64, stop: i64) -> Result<Protocol, DBError> {
match server.current_storage()?.ltrim(key, start, stop) {
Ok(_) => Ok(Protocol::SimpleString("OK".to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn lrem_cmd(server: &Server, key: &str, count: i64, element: &str) -> Result<Protocol, DBError> {
match server.current_storage()?.lrem(key, count, element) {
Ok(removed_count) => Ok(Protocol::SimpleString(removed_count.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn llen_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
match server.current_storage()?.llen(key) {
Ok(len) => Ok(Protocol::SimpleString(len.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn lpop_cmd(server: &Server, key: &str, count: &Option<u64>) -> Result<Protocol, DBError> {
let count_val = count.unwrap_or(1);
match server.current_storage()?.lpop(key, count_val) {
Ok(elements) => {
if elements.is_empty() {
if count.is_some() {
Ok(Protocol::Array(vec![]))
} else {
Ok(Protocol::Null)
}
} else if count.is_some() {
Ok(Protocol::Array(elements.into_iter().map(Protocol::BulkString).collect()))
} else {
Ok(Protocol::BulkString(elements[0].clone()))
}
},
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn rpop_cmd(server: &Server, key: &str, count: &Option<u64>) -> Result<Protocol, DBError> {
let count_val = count.unwrap_or(1);
match server.current_storage()?.rpop(key, count_val) {
Ok(elements) => {
if elements.is_empty() {
if count.is_some() {
Ok(Protocol::Array(vec![]))
} else {
Ok(Protocol::Null)
}
} else if count.is_some() {
Ok(Protocol::Array(elements.into_iter().map(Protocol::BulkString).collect()))
} else {
Ok(Protocol::BulkString(elements[0].clone()))
}
},
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn lpush_cmd(server: &Server, key: &str, elements: &[String]) -> Result<Protocol, DBError> {
match server.current_storage()?.lpush(key, elements.to_vec()) {
Ok(len) => Ok(Protocol::SimpleString(len.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn rpush_cmd(server: &Server, key: &str, elements: &[String]) -> Result<Protocol, DBError> {
match server.current_storage()?.rpush(key, elements.to_vec()) {
Ok(len) => Ok(Protocol::SimpleString(len.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn exec_cmd(server: &mut Server) -> Result<Protocol, DBError> {
// Move the queued commands out of `server` so we drop the borrow immediately.
let cmds = if let Some(cmds) = server.queued_cmd.take() {
cmds
} else {
return Ok(Protocol::err("ERR EXEC without MULTI"));
};
let mut out = Vec::new();
for (cmd, _) in cmds {
// Use Box::pin to handle recursion in async function
let res = Box::pin(cmd.run(server)).await?;
out.push(res);
}
Ok(Protocol::Array(out))
}
async fn incr_cmd(server: &Server, key: &String) -> Result<Protocol, DBError> {
let storage = server.current_storage()?;
let current_value = storage.get(key)?;
let new_value = match current_value {
Some(v) => {
match v.parse::<i64>() {
Ok(num) => num + 1,
Err(_) => return Ok(Protocol::err("ERR value is not an integer or out of range")),
}
}
None => 1,
};
storage.set(key.clone(), new_value.to_string())?;
Ok(Protocol::SimpleString(new_value.to_string()))
}
fn config_get_cmd(name: &String, server: &Server) -> Result<Protocol, DBError> {
let value = match name.as_str() {
"dir" => Some(server.option.dir.clone()),
"dbfilename" => Some(format!("{}.db", server.selected_db)),
"databases" => Some("16".to_string()), // Hardcoded as per original logic
_ => None,
};
if let Some(val) = value {
Ok(Protocol::Array(vec![
Protocol::BulkString(name.clone()),
Protocol::BulkString(val),
]))
} else {
// Return an empty array for unknown config options, which is standard Redis behavior
Ok(Protocol::Array(vec![]))
}
}
async fn keys_cmd(server: &Server) -> Result<Protocol, DBError> {
let keys = server.current_storage()?.keys("*")?;
Ok(Protocol::Array(
keys.into_iter().map(Protocol::BulkString).collect(),
))
}
#[derive(Serialize)]
struct ServerInfo {
redis_version: String,
encrypted: bool,
selected_db: u64,
}
async fn info_cmd(server: &Server, section: &Option<String>) -> Result<Protocol, DBError> {
let info = ServerInfo {
redis_version: "7.0.0".to_string(),
encrypted: server.current_storage()?.is_encrypted(),
selected_db: server.selected_db,
};
let mut info_string = String::new();
info_string.push_str(&format!("# Server\n"));
info_string.push_str(&format!("redis_version:{}\n", info.redis_version));
info_string.push_str(&format!("encrypted:{}\n", if info.encrypted { 1 } else { 0 }));
info_string.push_str(&format!("# Keyspace\n"));
info_string.push_str(&format!("db{}:keys=0,expires=0,avg_ttl=0\n", info.selected_db));
match section {
Some(s) => match s.as_str() {
"replication" => Ok(Protocol::BulkString(
"role:master\nmaster_replid:8371b4fb1155b71f4a04d3e1bc3e18c4a990aeea\nmaster_repl_offset:0\n".to_string()
)),
_ => Err(DBError(format!("unsupported section {:?}", s))),
},
None => {
Ok(Protocol::BulkString(info_string))
}
}
}
async fn type_cmd(server: &Server, k: &String) -> Result<Protocol, DBError> {
match server.current_storage()?.get_key_type(k)? {
Some(type_str) => Ok(Protocol::SimpleString(type_str)),
None => Ok(Protocol::SimpleString("none".to_string())),
}
}
async fn del_cmd(server: &Server, k: &str) -> Result<Protocol, DBError> {
server.current_storage()?.del(k.to_string())?;
Ok(Protocol::SimpleString("1".to_string()))
}
async fn set_ex_cmd(
server: &Server,
k: &str,
v: &str,
x: &u128,
) -> Result<Protocol, DBError> {
server.current_storage()?.setx(k.to_string(), v.to_string(), *x * 1000)?;
Ok(Protocol::SimpleString("OK".to_string()))
}
async fn set_px_cmd(
server: &Server,
k: &str,
v: &str,
x: &u128,
) -> Result<Protocol, DBError> {
server.current_storage()?.setx(k.to_string(), v.to_string(), *x)?;
Ok(Protocol::SimpleString("OK".to_string()))
}
async fn set_cmd(server: &Server, k: &str, v: &str) -> Result<Protocol, DBError> {
server.current_storage()?.set(k.to_string(), v.to_string())?;
Ok(Protocol::SimpleString("OK".to_string()))
}
async fn get_cmd(server: &Server, k: &str) -> Result<Protocol, DBError> {
let v = server.current_storage()?.get(k)?;
Ok(v.map_or(Protocol::Null, Protocol::BulkString))
}
// Hash command implementations
async fn hset_cmd(server: &Server, key: &str, pairs: &[(String, String)]) -> Result<Protocol, DBError> {
let new_fields = server.current_storage()?.hset(key, pairs.to_vec())?;
Ok(Protocol::SimpleString(new_fields.to_string()))
}
async fn hget_cmd(server: &Server, key: &str, field: &str) -> Result<Protocol, DBError> {
match server.current_storage()?.hget(key, field) {
Ok(Some(value)) => Ok(Protocol::BulkString(value)),
Ok(None) => Ok(Protocol::Null),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn hgetall_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
match server.current_storage()?.hgetall(key) {
Ok(pairs) => {
let mut result = Vec::new();
for (field, value) in pairs {
result.push(Protocol::BulkString(field));
result.push(Protocol::BulkString(value));
}
Ok(Protocol::Array(result))
}
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn hdel_cmd(server: &Server, key: &str, fields: &[String]) -> Result<Protocol, DBError> {
match server.current_storage()?.hdel(key, fields.to_vec()) {
Ok(deleted) => Ok(Protocol::SimpleString(deleted.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn hexists_cmd(server: &Server, key: &str, field: &str) -> Result<Protocol, DBError> {
match server.current_storage()?.hexists(key, field) {
Ok(exists) => Ok(Protocol::SimpleString(if exists { "1" } else { "0" }.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn hkeys_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
match server.current_storage()?.hkeys(key) {
Ok(keys) => Ok(Protocol::Array(
keys.into_iter().map(Protocol::BulkString).collect(),
)),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn hvals_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
match server.current_storage()?.hvals(key) {
Ok(values) => Ok(Protocol::Array(
values.into_iter().map(Protocol::BulkString).collect(),
)),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn hlen_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
match server.current_storage()?.hlen(key) {
Ok(len) => Ok(Protocol::SimpleString(len.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn hmget_cmd(server: &Server, key: &str, fields: &[String]) -> Result<Protocol, DBError> {
match server.current_storage()?.hmget(key, fields.to_vec()) {
Ok(values) => {
let result: Vec<Protocol> = values
.into_iter()
.map(|v| v.map_or(Protocol::Null, Protocol::BulkString))
.collect();
Ok(Protocol::Array(result))
}
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn hsetnx_cmd(server: &Server, key: &str, field: &str, value: &str) -> Result<Protocol, DBError> {
match server.current_storage()?.hsetnx(key, field, value) {
Ok(was_set) => Ok(Protocol::SimpleString(if was_set { "1" } else { "0" }.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn scan_cmd(
server: &Server,
cursor: &u64,
pattern: Option<&str>,
count: &Option<u64>
) -> Result<Protocol, DBError> {
match server.current_storage()?.scan(*cursor, pattern, *count) {
Ok((next_cursor, key_value_pairs)) => {
let mut result = Vec::new();
result.push(Protocol::BulkString(next_cursor.to_string()));
// For SCAN, we only return the keys, not the values
let keys: Vec<Protocol> = key_value_pairs.into_iter().map(|(key, _)| Protocol::BulkString(key)).collect();
result.push(Protocol::Array(keys));
Ok(Protocol::Array(result))
}
Err(e) => Ok(Protocol::err(&format!("ERR {}", e.0))),
}
}
async fn hscan_cmd(
server: &Server,
key: &str,
cursor: &u64,
pattern: Option<&str>,
count: &Option<u64>
) -> Result<Protocol, DBError> {
match server.current_storage()?.hscan(key, *cursor, pattern, *count) {
Ok((next_cursor, field_value_pairs)) => {
let mut result = Vec::new();
result.push(Protocol::BulkString(next_cursor.to_string()));
// For HSCAN, we return field-value pairs flattened
let mut fields_and_values = Vec::new();
for (field, value) in field_value_pairs {
fields_and_values.push(Protocol::BulkString(field));
fields_and_values.push(Protocol::BulkString(value));
}
result.push(Protocol::Array(fields_and_values));
Ok(Protocol::Array(result))
}
Err(e) => Ok(Protocol::err(&format!("ERR {}", e.0))),
}
}
async fn ttl_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
match server.current_storage()?.ttl(key) {
Ok(ttl) => Ok(Protocol::SimpleString(ttl.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn exists_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
match server.current_storage()?.exists(key) {
Ok(exists) => Ok(Protocol::SimpleString(if exists { "1" } else { "0" }.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn client_setname_cmd(server: &mut Server, name: &str) -> Result<Protocol, DBError> {
server.client_name = Some(name.to_string());
Ok(Protocol::SimpleString("OK".to_string()))
}
async fn client_getname_cmd(server: &Server) -> Result<Protocol, DBError> {
match &server.client_name {
Some(name) => Ok(Protocol::BulkString(name.clone())),
None => Ok(Protocol::Null),
}
}

View File

@@ -0,0 +1,71 @@
#!/bin/bash
# Start the herodb server in the background
echo "Starting herodb server..."
cargo run -p herodb -- --dir /tmp/herodb_age_test --port 6382 --debug --encryption-key "testkey" &
SERVER_PID=$!
sleep 2 # Give the server a moment to start
REDIS_CLI="redis-cli -p 6382"
echo "--- Generating and Storing Encryption Keys ---"
# The new AGE commands are 'AGE KEYGEN <name>' etc., based on src/cmd.rs
# This script uses older commands like 'AGE.GENERATE_KEYPAIR alice'
# The demo script needs to be updated to match the implemented commands.
# Let's assume the commands in the script are what's expected for now,
# but note this discrepancy. The new commands are AGE KEYGEN etc.
# The script here uses a different syntax not found in src/cmd.rs like 'AGE.GENERATE_KEYPAIR'.
# For now, I will modify the script to fit the actual implementation.
echo "--- Generating and Storing Encryption Keys ---"
$REDIS_CLI AGE KEYGEN alice
$REDIS_CLI AGE KEYGEN bob
echo "--- Encrypting and Decrypting a Message ---"
MESSAGE="Hello, AGE encryption!"
# The new logic stores keys internally and does not expose a command to get the public key.
# We will encrypt by name.
ALICE_PUBKEY_REPLY=$($REDIS_CLI AGE KEYGEN alice | head -n 2 | tail -n 1)
echo "Alice's Public Key: $ALICE_PUBKEY_REPLY"
echo "Encrypting message: '$MESSAGE' with Alice's identity..."
# AGE.ENCRYPT recipient message. But since we use persistent keys, let's use ENCRYPTNAME
CIPHERTEXT=$($REDIS_CLI AGE ENCRYPTNAME alice "$MESSAGE")
echo "Ciphertext: $CIPHERTEXT"
echo "Decrypting ciphertext with Alice's private key..."
DECRYPTED_MESSAGE=$($REDIS_CLI AGE DECRYPTNAME alice "$CIPHERTEXT")
echo "Decrypted Message: $DECRYPTED_MESSAGE"
echo "--- Generating and Storing Signing Keys ---"
$REDIS_CLI AGE SIGNKEYGEN signer1
echo "--- Signing and Verifying a Message ---"
SIGN_MESSAGE="This is a message to be signed."
# Similar to above, we don't have GET_SIGN_PUBKEY. We will verify by name.
echo "Signing message: '$SIGN_MESSAGE' with signer1's private key..."
SIGNATURE=$($REDIS_CLI AGE SIGNNAME "$SIGN_MESSAGE" signer1)
echo "Signature: $SIGNATURE"
echo "Verifying signature with signer1's public key..."
VERIFY_RESULT=$($REDIS_CLI AGE VERIFYNAME signer1 "$SIGN_MESSAGE" "$SIGNATURE")
echo "Verification Result: $VERIFY_RESULT"
# There is no DELETE_KEYPAIR command in the implementation
echo "--- Cleaning up keys (manual in herodb) ---"
# We would use DEL for age:key:alice, etc.
$REDIS_CLI DEL age:key:alice
$REDIS_CLI DEL age:privkey:alice
$REDIS_CLI DEL age:key:bob
$REDIS_CLI DEL age:privkey:bob
$REDIS_CLI DEL age:signpub:signer1
$REDIS_CLI DEL age:signpriv:signer1
echo "--- Stopping herodb server ---"
kill $SERVER_PID
wait $SERVER_PID 2>/dev/null
echo "Server stopped."
echo "Bash demo complete."

View File

@@ -0,0 +1,83 @@
use std::io::{Read, Write};
use std::net::TcpStream;
// Minimal RESP helpers
fn arr(parts: &[&str]) -> String {
let mut out = format!("*{}\r\n", parts.len());
for p in parts {
out.push_str(&format!("${}\r\n{}\r\n", p.len(), p));
}
out
}
fn read_reply(s: &mut TcpStream) -> String {
let mut buf = [0u8; 65536];
let n = s.read(&mut buf).unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}
fn parse_two_bulk(reply: &str) -> Option<(String,String)> {
let mut lines = reply.split("\r\n");
if lines.next()? != "*2" { return None; }
let _n = lines.next()?;
let a = lines.next()?.to_string();
let _m = lines.next()?;
let b = lines.next()?.to_string();
Some((a,b))
}
fn parse_bulk(reply: &str) -> Option<String> {
let mut lines = reply.split("\r\n");
let hdr = lines.next()?;
if !hdr.starts_with('$') { return None; }
Some(lines.next()?.to_string())
}
fn parse_simple(reply: &str) -> Option<String> {
let mut lines = reply.split("\r\n");
let hdr = lines.next()?;
if !hdr.starts_with('+') { return None; }
Some(hdr[1..].to_string())
}
fn main() {
let mut args = std::env::args().skip(1);
let host = args.next().unwrap_or_else(|| "127.0.0.1".into());
let port = args.next().unwrap_or_else(|| "6379".into());
let addr = format!("{host}:{port}");
println!("Connecting to {addr}...");
let mut s = TcpStream::connect(addr).expect("connect");
// Generate & persist X25519 enc keys under name "alice"
s.write_all(arr(&["age","keygen","alice"]).as_bytes()).unwrap();
let (_alice_recip, _alice_ident) = parse_two_bulk(&read_reply(&mut s)).expect("gen enc");
// Generate & persist Ed25519 signing key under name "signer"
s.write_all(arr(&["age","signkeygen","signer"]).as_bytes()).unwrap();
let (_verify, _secret) = parse_two_bulk(&read_reply(&mut s)).expect("gen sign");
// Encrypt by name
let msg = "hello from persistent keys";
s.write_all(arr(&["age","encryptname","alice", msg]).as_bytes()).unwrap();
let ct_b64 = parse_bulk(&read_reply(&mut s)).expect("ct b64");
println!("ciphertext b64: {}", ct_b64);
// Decrypt by name
s.write_all(arr(&["age","decryptname","alice", &ct_b64]).as_bytes()).unwrap();
let pt = parse_bulk(&read_reply(&mut s)).expect("pt");
assert_eq!(pt, msg);
println!("decrypted ok");
// Sign by name
s.write_all(arr(&["age","signname","signer", msg]).as_bytes()).unwrap();
let sig_b64 = parse_bulk(&read_reply(&mut s)).expect("sig b64");
// Verify by name
s.write_all(arr(&["age","verifyname","signer", msg, &sig_b64]).as_bytes()).unwrap();
let ok = parse_simple(&read_reply(&mut s)).expect("verify");
assert_eq!(ok, "1");
println!("signature verified");
// List names
s.write_all(arr(&["age","list"]).as_bytes()).unwrap();
let list = read_reply(&mut s);
println!("LIST -> {list}");
println!("✔ persistent AGE workflow complete.");
}

4
crates/herodb/lib.rs Normal file
View File

@@ -0,0 +1,4 @@
pub mod age; // NEW
pub mod cmd;
pub mod protocol;
pub mod server;

View File

@@ -2,7 +2,7 @@
use tokio::net::TcpListener; use tokio::net::TcpListener;
use redis_rs::server; use herodb::server;
use clap::Parser; use clap::Parser;
@@ -14,10 +14,22 @@ struct Args {
#[arg(long)] #[arg(long)]
dir: String, dir: String,
/// The port of the Redis server, default is 6379 if not specified /// The port of the Redis server, default is 6379 if not specified
#[arg(long)] #[arg(long)]
port: Option<u16>, port: Option<u16>,
/// Enable debug mode
#[arg(long)]
debug: bool,
/// Master encryption key for encrypted databases
#[arg(long)]
encryption_key: Option<String>,
/// Encrypt the database
#[arg(long)]
encrypt: bool,
} }
#[tokio::main] #[tokio::main]
@@ -33,14 +45,20 @@ async fn main() {
.unwrap(); .unwrap();
// new DB option // new DB option
let option = redis_rs::options::DBOption { let option = herodb::options::DBOption {
dir: args.dir, dir: args.dir,
port, port,
debug: args.debug,
encryption_key: args.encryption_key,
encrypt: args.encrypt,
}; };
// new server // new server
let server = server::Server::new(option).await; let server = server::Server::new(option).await;
// Add a small delay to ensure the port is ready
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
// accept new connections // accept new connections
loop { loop {
let stream = listener.accept().await; let stream = listener.accept().await;

8
crates/herodb/options.rs Normal file
View File

@@ -0,0 +1,8 @@
#[derive(Clone)]
pub struct DBOption {
pub dir: String,
pub port: u16,
pub debug: bool,
pub encrypt: bool,
pub encryption_key: Option<String>, // Master encryption key
}

View File

@@ -8,6 +8,7 @@ pub enum Protocol {
BulkString(String), BulkString(String),
Null, Null,
Array(Vec<Protocol>), Array(Vec<Protocol>),
Error(String), // NEW
} }
impl fmt::Display for Protocol { impl fmt::Display for Protocol {
@@ -17,7 +18,7 @@ impl fmt::Display for Protocol {
} }
impl Protocol { impl Protocol {
pub fn from(protocol: &str) -> Result<(Self, usize), DBError> { pub fn from(protocol: &str) -> Result<(Self, &str), DBError> {
let ret = match protocol.chars().nth(0) { let ret = match protocol.chars().nth(0) {
Some('+') => Self::parse_simple_string_sfx(&protocol[1..]), Some('+') => Self::parse_simple_string_sfx(&protocol[1..]),
Some('$') => Self::parse_bulk_string_sfx(&protocol[1..]), Some('$') => Self::parse_bulk_string_sfx(&protocol[1..]),
@@ -27,10 +28,7 @@ impl Protocol {
protocol protocol
))), ))),
}; };
match ret { ret
Ok((p, s)) => Ok((p, s + 1)),
Err(e) => Err(e),
}
} }
pub fn from_vec(array: Vec<&str>) -> Self { pub fn from_vec(array: Vec<&str>) -> Self {
@@ -48,7 +46,7 @@ impl Protocol {
#[inline] #[inline]
pub fn err(msg: &str) -> Self { pub fn err(msg: &str) -> Self {
Protocol::SimpleString(msg.to_string()) Protocol::Error(msg.to_string())
} }
#[inline] #[inline]
@@ -72,28 +70,25 @@ impl Protocol {
Protocol::BulkString(s) => s.to_string(), Protocol::BulkString(s) => s.to_string(),
Protocol::Null => "".to_string(), Protocol::Null => "".to_string(),
Protocol::Array(s) => s.iter().map(|x| x.decode()).collect::<Vec<_>>().join(" "), Protocol::Array(s) => s.iter().map(|x| x.decode()).collect::<Vec<_>>().join(" "),
Protocol::Error(s) => s.to_string(),
} }
} }
pub fn encode(&self) -> String { pub fn encode(&self) -> String {
match self { match self {
Protocol::SimpleString(s) => format!("+{}\r\n", s), Protocol::SimpleString(s) => format!("+{}\r\n", s),
Protocol::BulkString(s) => format!("${}\r\n{}\r\n", s.len(), s), Protocol::BulkString(s) => format!("${}\r\n{}\r\n", s.len(), s),
Protocol::Array(ss) => { Protocol::Array(ss) => {
format!("*{}\r\n", ss.len()) format!("*{}\r\n", ss.len()) + &ss.iter().map(|x| x.encode()).collect::<String>()
+ ss.iter()
.map(|x| x.encode())
.collect::<Vec<_>>()
.join("")
.as_str()
} }
Protocol::Null => "$-1\r\n".to_string(), Protocol::Null => "$-1\r\n".to_string(),
Protocol::Error(s) => format!("-{}\r\n", s), // proper RESP error
} }
} }
fn parse_simple_string_sfx(protocol: &str) -> Result<(Self, usize), DBError> { fn parse_simple_string_sfx(protocol: &str) -> Result<(Self, &str), DBError> {
match protocol.find("\r\n") { match protocol.find("\r\n") {
Some(x) => Ok((Self::SimpleString(protocol[..x].to_string()), x + 2)), Some(x) => Ok((Self::SimpleString(protocol[..x].to_string()), &protocol[x + 2..])),
_ => Err(DBError(format!( _ => Err(DBError(format!(
"[new simple string] unsupported protocol: {:?}", "[new simple string] unsupported protocol: {:?}",
protocol protocol
@@ -101,27 +96,20 @@ impl Protocol {
} }
} }
fn parse_bulk_string_sfx(protocol: &str) -> Result<(Self, usize), DBError> { fn parse_bulk_string_sfx(protocol: &str) -> Result<(Self, &str), DBError> {
if let Some(len) = protocol.find("\r\n") { if let Some(len_end) = protocol.find("\r\n") {
let size = Self::parse_usize(&protocol[..len])?; let size = Self::parse_usize(&protocol[..len_end])?;
if let Some(data_len) = protocol[len + 2..].find("\r\n") { let data_start = len_end + 2;
let s = Self::parse_string(&protocol[len + 2..len + 2 + data_len])?; let data_end = data_start + size;
if size != s.len() { let s = Self::parse_string(&protocol[data_start..data_end])?;
Err(DBError(format!(
"[new bulk string] unmatched string length in prototocl {:?}", if protocol.len() < data_end + 2 || &protocol[data_end..data_end+2] != "\r\n" {
protocol, Err(DBError(format!(
))) "[new bulk string] unmatched string length in prototocl {:?}",
} else { protocol,
Ok((
Protocol::BulkString(s.to_lowercase()),
len + 2 + data_len + 2,
))
}
} else {
Err(DBError(format!(
"[new bulk string] unsupported protocol: {:?}",
protocol
))) )))
} else {
Ok((Protocol::BulkString(s), &protocol[data_end + 2..]))
} }
} else { } else {
Err(DBError(format!( Err(DBError(format!(
@@ -131,46 +119,41 @@ impl Protocol {
} }
} }
fn parse_array_sfx(s: &str) -> Result<(Self, usize), DBError> { fn parse_array_sfx(s: &str) -> Result<(Self, &str), DBError> {
let mut offset = 0; if let Some(len_end) = s.find("\r\n") {
match s.find("\r\n") { let array_len = s[..len_end].parse::<usize>()?;
Some(x) => { let mut remaining = &s[len_end + 2..];
let array_len = s[..x].parse::<usize>()?; let mut vec = vec![];
offset += x + 2; for _ in 0..array_len {
let mut vec = vec![]; let (p, rem) = Protocol::from(remaining)?;
for _ in 0..array_len { vec.push(p);
match Protocol::from(&s[offset..]) { remaining = rem;
Ok((p, len)) => {
offset += len;
vec.push(p);
}
Err(e) => {
return Err(e);
}
}
}
Ok((Protocol::Array(vec), offset))
} }
_ => Err(DBError(format!( Ok((Protocol::Array(vec), remaining))
} else {
Err(DBError(format!(
"[new array] unsupported protocol: {:?}", "[new array] unsupported protocol: {:?}",
s s
))), )))
} }
} }
fn parse_usize(protocol: &str) -> Result<usize, DBError> { fn parse_usize(protocol: &str) -> Result<usize, DBError> {
match protocol.len() { if protocol.is_empty() {
0 => Err(DBError(format!("parse usize error: {:?}", protocol))), Err(DBError("Cannot parse usize from empty string".to_string()))
_ => Ok(protocol } else {
protocol
.parse::<usize>() .parse::<usize>()
.map_err(|_| DBError(format!("parse usize error: {}", protocol)))?), .map_err(|_| DBError(format!("Failed to parse usize from: {}", protocol)))
} }
} }
fn parse_string(protocol: &str) -> Result<String, DBError> { fn parse_string(protocol: &str) -> Result<String, DBError> {
match protocol.len() { if protocol.is_empty() {
0 => Err(DBError(format!("parse usize error: {:?}", protocol))), // Allow empty strings, but handle appropriately
_ => Ok(protocol.to_string()), Ok("".to_string())
} else {
Ok(protocol.to_string())
} }
} }
} }

136
crates/herodb/server.rs Normal file
View File

@@ -0,0 +1,136 @@
use core::str;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
use crate::cmd::Cmd;
use crate::error::DBError;
use crate::options;
use crate::protocol::Protocol;
use crate::storage::Storage;
#[derive(Clone)]
pub struct Server {
pub db_cache: std::sync::Arc<std::sync::RwLock<HashMap<u64, Arc<Storage>>>>,
pub option: options::DBOption,
pub client_name: Option<String>,
pub selected_db: u64, // Changed from usize to u64
pub queued_cmd: Option<Vec<(Cmd, Protocol)>>,
}
impl Server {
pub async fn new(option: options::DBOption) -> Self {
Server {
db_cache: Arc::new(std::sync::RwLock::new(HashMap::new())),
option,
client_name: None,
selected_db: 0,
queued_cmd: None,
}
}
pub fn current_storage(&self) -> Result<Arc<libdbstorage::Storage>, libdbstorage::DBError> {
let mut cache = self.db_cache.write().unwrap();
if let Some(storage) = cache.get(&self.selected_db) {
return Ok(storage.clone());
}
// Create new database file
let db_file_path = std::path::PathBuf::from(self.option.dir.clone())
.join(format!("{}.db", self.selected_db));
// Ensure the directory exists before creating the database file
if let Some(parent_dir) = db_file_path.parent() {
std::fs::create_dir_all(parent_dir).map_err(|e| {
DBError(format!("Failed to create directory {}: {}", parent_dir.display(), e))
})?;
}
println!("Creating new db file: {}", db_file_path.display());
let storage = Arc::new(Storage::new(
db_file_path,
self.should_encrypt_db(self.selected_db),
self.option.encryption_key.as_deref()
)?);
cache.insert(self.selected_db, storage.clone());
Ok(storage)
}
fn should_encrypt_db(&self, db_index: u64) -> bool {
// DB 0-9 are non-encrypted, DB 10+ are encrypted
self.option.encrypt && db_index >= 10
}
pub async fn handle(
&mut self,
mut stream: tokio::net::TcpStream,
) -> Result<(), DBError> {
let mut buf = [0; 512];
loop {
let len = match stream.read(&mut buf).await {
Ok(0) => {
println!("[handle] connection closed");
return Ok(());
}
Ok(len) => len,
Err(e) => {
println!("[handle] read error: {:?}", e);
return Err(e.into());
}
};
let mut s = str::from_utf8(&buf[..len])?;
while !s.is_empty() {
let (cmd, protocol, remaining) = match Cmd::from(s) {
Ok((cmd, protocol, remaining)) => (cmd, protocol, remaining),
Err(e) => {
println!("\x1b[31;1mprotocol error: {:?}\x1b[0m", e);
(Cmd::Unknow("protocol_error".to_string()), Protocol::err(&format!("protocol error: {}", e.0)), "")
}
};
s = remaining;
if self.option.debug {
println!("\x1b[34;1mgot command: {:?}, protocol: {:?}\x1b[0m", cmd, protocol);
} else {
println!("got command: {:?}, protocol: {:?}", cmd, protocol);
}
// Check if this is a QUIT command before processing
let is_quit = matches!(cmd, Cmd::Quit);
let res = match cmd.run(self).await {
Ok(p) => p,
Err(e) => {
if self.option.debug {
eprintln!("[run error] {:?}", e);
}
Protocol::err(&format!("ERR {}", e.0))
}
};
if self.option.debug {
println!("\x1b[34;1mqueued cmd {:?}\x1b[0m", self.queued_cmd);
println!("\x1b[32;1mgoing to send response {}\x1b[0m", res.encode());
} else {
print!("queued cmd {:?}", self.queued_cmd);
println!("going to send response {}", res.encode());
}
_ = stream.write(res.encode().as_bytes()).await?;
// If this was a QUIT command, close the connection
if is_quit {
println!("[handle] QUIT command received, closing connection");
return Ok(());
}
}
}
}
}

View File

@@ -0,0 +1 @@
fn main() {}

View File

@@ -0,0 +1,62 @@
use herodb::{server::Server, options::DBOption};
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::time::sleep;
// Helper function to send command and get response
async fn send_command(stream: &mut TcpStream, command: &str) -> String {
stream.write_all(command.as_bytes()).await.unwrap();
let mut buffer = [0; 1024];
let n = stream.read(&mut buffer).await.unwrap();
String::from_utf8_lossy(&buffer[..n]).to_string()
}
#[tokio::test]
async fn debug_hset_simple() {
// Clean up any existing test database
let test_dir = "/tmp/herodb_debug_hset";
let _ = std::fs::remove_dir_all(test_dir);
std::fs::create_dir_all(test_dir).unwrap();
let port = 16500;
let option = DBOption {
dir: test_dir.to_string(),
port,
debug: false,
encrypt: false,
encryption_key: None,
};
let mut server = Server::new(option).await;
// Start server in background
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(200)).await;
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port)).await.unwrap();
// Test simple HSET
println!("Testing HSET...");
let response = send_command(&mut stream, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n").await;
println!("HSET response: {}", response);
assert!(response.contains("1"), "Expected '1' but got: {}", response);
// Test HGET
println!("Testing HGET...");
let response = send_command(&mut stream, "*3\r\n$4\r\nHGET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await;
println!("HGET response: {}", response);
assert!(response.contains("value1"), "Expected 'value1' but got: {}", response);
}

View File

@@ -0,0 +1,56 @@
use herodb::{server::Server, options::DBOption};
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::time::sleep;
#[tokio::test]
async fn debug_hset_return_value() {
let test_dir = "/tmp/herodb_debug_hset_return";
// Clean up any existing test data
let _ = std::fs::remove_dir_all(&test_dir);
std::fs::create_dir_all(&test_dir).unwrap();
let option = DBOption {
dir: test_dir.to_string(),
port: 16390,
debug: false,
encrypt: false,
encryption_key: None,
};
let mut server = Server::new(option).await;
// Start server in background
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind("127.0.0.1:16390")
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(200)).await;
// Connect and test HSET
let mut stream = TcpStream::connect("127.0.0.1:16390").await.unwrap();
// Send HSET command
let cmd = "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n";
stream.write_all(cmd.as_bytes()).await.unwrap();
let mut buffer = [0; 1024];
let n = stream.read(&mut buffer).await.unwrap();
let response = String::from_utf8_lossy(&buffer[..n]);
println!("HSET response: {}", response);
println!("Response bytes: {:?}", &buffer[..n]);
// Check if response contains "1"
assert!(response.contains("1"), "Expected response to contain '1', got: {}", response);
}

View File

@@ -0,0 +1,35 @@
use herodb::protocol::Protocol;
use herodb::cmd::Cmd;
#[test]
fn test_protocol_parsing() {
// Test TYPE command parsing
let type_cmd = "*2\r\n$4\r\nTYPE\r\n$7\r\nnoexist\r\n";
println!("Parsing TYPE command: {}", type_cmd.replace("\r\n", "\\r\\n"));
match Protocol::from(type_cmd) {
Ok((protocol, _)) => {
println!("Protocol parsed successfully: {:?}", protocol);
match Cmd::from(type_cmd) {
Ok((cmd, _, _)) => println!("Command parsed successfully: {:?}", cmd),
Err(e) => println!("Command parsing failed: {:?}", e),
}
}
Err(e) => println!("Protocol parsing failed: {:?}", e),
}
// Test HEXISTS command parsing
let hexists_cmd = "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$7\r\nnoexist\r\n";
println!("\nParsing HEXISTS command: {}", hexists_cmd.replace("\r\n", "\\r\\n"));
match Protocol::from(hexists_cmd) {
Ok((protocol, _)) => {
println!("Protocol parsed successfully: {:?}", protocol);
match Cmd::from(hexists_cmd) {
Ok((cmd, _, _)) => println!("Command parsed successfully: {:?}", cmd),
Err(e) => println!("Command parsing failed: {:?}", e),
}
}
Err(e) => println!("Protocol parsing failed: {:?}", e),
}
}

View File

@@ -0,0 +1,317 @@
use redis::{Client, Commands, Connection, RedisResult};
use std::process::{Child, Command};
use std::time::Duration;
use tokio::time::sleep;
// Helper function to get Redis connection, retrying until successful
fn get_redis_connection(port: u16) -> Connection {
let connection_info = format!("redis://127.0.0.1:{}", port);
let client = Client::open(connection_info).unwrap();
let mut attempts = 0;
loop {
match client.get_connection() {
Ok(mut conn) => {
if redis::cmd("PING").query::<String>(&mut conn).is_ok() {
return conn;
}
}
Err(e) => {
if attempts >= 20 {
panic!(
"Failed to connect to Redis server after 20 attempts: {}",
e
);
}
}
}
attempts += 1;
std::thread::sleep(Duration::from_millis(100));
}
}
// A guard to ensure the server process is killed when it goes out of scope
struct ServerProcessGuard {
process: Child,
test_dir: String,
}
impl Drop for ServerProcessGuard {
fn drop(&mut self) {
println!("Killing server process (pid: {})...", self.process.id());
if let Err(e) = self.process.kill() {
eprintln!("Failed to kill server process: {}", e);
}
match self.process.wait() {
Ok(status) => println!("Server process exited with: {}", status),
Err(e) => eprintln!("Failed to wait on server process: {}", e),
}
// Clean up the specific test directory
println!("Cleaning up test directory: {}", self.test_dir);
if let Err(e) = std::fs::remove_dir_all(&self.test_dir) {
eprintln!("Failed to clean up test directory: {}", e);
}
}
}
// Helper to set up the server and return a connection
fn setup_server() -> (ServerProcessGuard, u16) {
use std::sync::atomic::{AtomicU16, Ordering};
static PORT_COUNTER: AtomicU16 = AtomicU16::new(16400);
let port = PORT_COUNTER.fetch_add(1, Ordering::SeqCst);
let test_dir = format!("/tmp/herodb_test_{}", port);
// Clean up previous test data
if std::path::Path::new(&test_dir).exists() {
let _ = std::fs::remove_dir_all(&test_dir);
}
std::fs::create_dir_all(&test_dir).unwrap();
// Start the server in a subprocess
let child = Command::new("cargo")
.args(&[
"run",
"--",
"--dir",
&test_dir,
"--port",
&port.to_string(),
"--debug",
])
.spawn()
.expect("Failed to start server process");
// Create a new guard that also owns the test directory path
let guard = ServerProcessGuard {
process: child,
test_dir,
};
// Give the server a moment to start
std::thread::sleep(Duration::from_millis(500));
(guard, port)
}
async fn cleanup_keys(conn: &mut Connection) {
let keys: Vec<String> = redis::cmd("KEYS").arg("*").query(conn).unwrap();
if !keys.is_empty() {
for key in keys {
let _: () = redis::cmd("DEL").arg(key).query(conn).unwrap();
}
}
}
#[tokio::test]
async fn all_tests() {
let (_server_guard, port) = setup_server();
let mut conn = get_redis_connection(port);
// Run all tests using the same connection
test_basic_ping(&mut conn).await;
test_string_operations(&mut conn).await;
test_incr_operations(&mut conn).await;
test_hash_operations(&mut conn).await;
test_expiration(&mut conn).await;
test_scan_operations(&mut conn).await;
test_scan_with_count(&mut conn).await;
test_hscan_operations(&mut conn).await;
test_transaction_operations(&mut conn).await;
test_discard_transaction(&mut conn).await;
test_type_command(&mut conn).await;
test_info_command(&mut conn).await;
}
async fn test_basic_ping(conn: &mut Connection) {
cleanup_keys(conn).await;
let result: String = redis::cmd("PING").query(conn).unwrap();
assert_eq!(result, "PONG");
cleanup_keys(conn).await;
}
async fn test_string_operations(conn: &mut Connection) {
cleanup_keys(conn).await;
let _: () = conn.set("key", "value").unwrap();
let result: String = conn.get("key").unwrap();
assert_eq!(result, "value");
let result: Option<String> = conn.get("noexist").unwrap();
assert_eq!(result, None);
let deleted: i32 = conn.del("key").unwrap();
assert_eq!(deleted, 1);
let result: Option<String> = conn.get("key").unwrap();
assert_eq!(result, None);
cleanup_keys(conn).await;
}
async fn test_incr_operations(conn: &mut Connection) {
cleanup_keys(conn).await;
let result: i32 = redis::cmd("INCR").arg("counter").query(conn).unwrap();
assert_eq!(result, 1);
let result: i32 = redis::cmd("INCR").arg("counter").query(conn).unwrap();
assert_eq!(result, 2);
let _: () = conn.set("string", "hello").unwrap();
let result: RedisResult<i32> = redis::cmd("INCR").arg("string").query(conn);
assert!(result.is_err());
cleanup_keys(conn).await;
}
async fn test_hash_operations(conn: &mut Connection) {
cleanup_keys(conn).await;
let result: i32 = conn.hset("hash", "field1", "value1").unwrap();
assert_eq!(result, 1);
let result: String = conn.hget("hash", "field1").unwrap();
assert_eq!(result, "value1");
let _: () = conn.hset("hash", "field2", "value2").unwrap();
let _: () = conn.hset("hash", "field3", "value3").unwrap();
let result: std::collections::HashMap<String, String> = conn.hgetall("hash").unwrap();
assert_eq!(result.len(), 3);
assert_eq!(result.get("field1").unwrap(), "value1");
assert_eq!(result.get("field2").unwrap(), "value2");
assert_eq!(result.get("field3").unwrap(), "value3");
let result: i32 = conn.hlen("hash").unwrap();
assert_eq!(result, 3);
let result: bool = conn.hexists("hash", "field1").unwrap();
assert_eq!(result, true);
let result: bool = conn.hexists("hash", "noexist").unwrap();
assert_eq!(result, false);
let result: i32 = conn.hdel("hash", "field1").unwrap();
assert_eq!(result, 1);
let mut result: Vec<String> = conn.hkeys("hash").unwrap();
result.sort();
assert_eq!(result, vec!["field2", "field3"]);
let mut result: Vec<String> = conn.hvals("hash").unwrap();
result.sort();
assert_eq!(result, vec!["value2", "value3"]);
cleanup_keys(conn).await;
}
async fn test_expiration(conn: &mut Connection) {
cleanup_keys(conn).await;
let _: () = conn.set_ex("expkey", "value", 1).unwrap();
let result: i32 = conn.ttl("expkey").unwrap();
assert!(result == 1 || result == 0);
let result: bool = conn.exists("expkey").unwrap();
assert_eq!(result, true);
sleep(Duration::from_millis(1100)).await;
let result: Option<String> = conn.get("expkey").unwrap();
assert_eq!(result, None);
let result: i32 = conn.ttl("expkey").unwrap();
assert_eq!(result, -2);
let result: bool = conn.exists("expkey").unwrap();
assert_eq!(result, false);
cleanup_keys(conn).await;
}
async fn test_scan_operations(conn: &mut Connection) {
cleanup_keys(conn).await;
for i in 0..5 {
let _: () = conn.set(format!("key{}", i), format!("value{}", i)).unwrap();
}
let result: (u64, Vec<String>) = redis::cmd("SCAN")
.arg(0)
.arg("MATCH")
.arg("key*")
.arg("COUNT")
.arg(10)
.query(conn)
.unwrap();
let (cursor, keys) = result;
assert_eq!(cursor, 0);
assert_eq!(keys.len(), 5);
cleanup_keys(conn).await;
}
async fn test_scan_with_count(conn: &mut Connection) {
cleanup_keys(conn).await;
for i in 0..15 {
let _: () = conn.set(format!("scan_key{}", i), i).unwrap();
}
let mut cursor = 0;
let mut all_keys = std::collections::HashSet::new();
loop {
let (next_cursor, keys): (u64, Vec<String>) = redis::cmd("SCAN")
.arg(cursor)
.arg("MATCH")
.arg("scan_key*")
.arg("COUNT")
.arg(5)
.query(conn)
.unwrap();
for key in keys {
all_keys.insert(key);
}
cursor = next_cursor;
if cursor == 0 {
break;
}
}
assert_eq!(all_keys.len(), 15);
cleanup_keys(conn).await;
}
async fn test_hscan_operations(conn: &mut Connection) {
cleanup_keys(conn).await;
for i in 0..3 {
let _: () = conn.hset("testhash", format!("field{}", i), format!("value{}", i)).unwrap();
}
let result: (u64, Vec<String>) = redis::cmd("HSCAN")
.arg("testhash")
.arg(0)
.arg("MATCH")
.arg("*")
.arg("COUNT")
.arg(10)
.query(conn)
.unwrap();
let (cursor, fields) = result;
assert_eq!(cursor, 0);
assert_eq!(fields.len(), 6);
cleanup_keys(conn).await;
}
async fn test_transaction_operations(conn: &mut Connection) {
cleanup_keys(conn).await;
let _: () = redis::cmd("MULTI").query(conn).unwrap();
let _: () = redis::cmd("SET").arg("key1").arg("value1").query(conn).unwrap();
let _: () = redis::cmd("SET").arg("key2").arg("value2").query(conn).unwrap();
let _: Vec<String> = redis::cmd("EXEC").query(conn).unwrap();
let result: String = conn.get("key1").unwrap();
assert_eq!(result, "value1");
let result: String = conn.get("key2").unwrap();
assert_eq!(result, "value2");
cleanup_keys(conn).await;
}
async fn test_discard_transaction(conn: &mut Connection) {
cleanup_keys(conn).await;
let _: () = redis::cmd("MULTI").query(conn).unwrap();
let _: () = redis::cmd("SET").arg("discard").arg("value").query(conn).unwrap();
let _: () = redis::cmd("DISCARD").query(conn).unwrap();
let result: Option<String> = conn.get("discard").unwrap();
assert_eq!(result, None);
cleanup_keys(conn).await;
}
async fn test_type_command(conn: &mut Connection) {
cleanup_keys(conn).await;
let _: () = conn.set("string", "value").unwrap();
let result: String = redis::cmd("TYPE").arg("string").query(conn).unwrap();
assert_eq!(result, "string");
let _: () = conn.hset("hash", "field", "value").unwrap();
let result: String = redis::cmd("TYPE").arg("hash").query(conn).unwrap();
assert_eq!(result, "hash");
let result: String = redis::cmd("TYPE").arg("noexist").query(conn).unwrap();
assert_eq!(result, "none");
cleanup_keys(conn).await;
}
async fn test_info_command(conn: &mut Connection) {
cleanup_keys(conn).await;
let result: String = redis::cmd("INFO").query(conn).unwrap();
assert!(result.contains("redis_version"));
let result: String = redis::cmd("INFO").arg("replication").query(conn).unwrap();
assert!(result.contains("role:master"));
cleanup_keys(conn).await;
}

View File

@@ -0,0 +1,608 @@
use herodb::{server::Server, options::DBOption};
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::time::sleep;
// Helper function to start a test server
async fn start_test_server(test_name: &str) -> (Server, u16) {
use std::sync::atomic::{AtomicU16, Ordering};
static PORT_COUNTER: AtomicU16 = AtomicU16::new(16379);
let port = PORT_COUNTER.fetch_add(1, Ordering::SeqCst);
let test_dir = format!("/tmp/herodb_test_{}", test_name);
// Clean up and create test directory
let _ = std::fs::remove_dir_all(&test_dir);
std::fs::create_dir_all(&test_dir).unwrap();
let option = DBOption {
dir: test_dir,
port,
debug: true,
encrypt: false,
encryption_key: None,
};
let server = Server::new(option).await;
(server, port)
}
// Helper function to connect to the test server
async fn connect_to_server(port: u16) -> TcpStream {
let mut attempts = 0;
loop {
match TcpStream::connect(format!("127.0.0.1:{}", port)).await {
Ok(stream) => return stream,
Err(_) if attempts < 10 => {
attempts += 1;
sleep(Duration::from_millis(100)).await;
}
Err(e) => panic!("Failed to connect to test server: {}", e),
}
}
}
// Helper function to send command and get response
async fn send_command(stream: &mut TcpStream, command: &str) -> String {
stream.write_all(command.as_bytes()).await.unwrap();
let mut buffer = [0; 1024];
let n = stream.read(&mut buffer).await.unwrap();
String::from_utf8_lossy(&buffer[..n]).to_string()
}
#[tokio::test]
async fn test_basic_ping() {
let (mut server, port) = start_test_server("ping").await;
// Start server in background
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
let response = send_command(&mut stream, "*1\r\n$4\r\nPING\r\n").await;
assert!(response.contains("PONG"));
}
#[tokio::test]
async fn test_string_operations() {
let (mut server, port) = start_test_server("string").await;
// Start server in background
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
// Test SET
let response = send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n").await;
assert!(response.contains("OK"));
// Test GET
let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$3\r\nkey\r\n").await;
assert!(response.contains("value"));
// Test GET non-existent key
let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$7\r\nnoexist\r\n").await;
assert!(response.contains("$-1")); // NULL response
// Test DEL
let response = send_command(&mut stream, "*2\r\n$3\r\nDEL\r\n$3\r\nkey\r\n").await;
assert!(response.contains("1"));
// Test GET after DEL
let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$3\r\nkey\r\n").await;
assert!(response.contains("$-1")); // NULL response
}
#[tokio::test]
async fn test_incr_operations() {
let (mut server, port) = start_test_server("incr").await;
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
// Test INCR on non-existent key
let response = send_command(&mut stream, "*2\r\n$4\r\nINCR\r\n$7\r\ncounter\r\n").await;
assert!(response.contains("1"));
// Test INCR on existing key
let response = send_command(&mut stream, "*2\r\n$4\r\nINCR\r\n$7\r\ncounter\r\n").await;
assert!(response.contains("2"));
// Test INCR on string value (should fail)
send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$6\r\nstring\r\n$5\r\nhello\r\n").await;
let response = send_command(&mut stream, "*2\r\n$4\r\nINCR\r\n$6\r\nstring\r\n").await;
assert!(response.contains("ERR"));
}
#[tokio::test]
async fn test_hash_operations() {
let (mut server, port) = start_test_server("hash").await;
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
// Test HSET
let response = send_command(&mut stream, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n").await;
assert!(response.contains("1")); // 1 new field
// Test HGET
let response = send_command(&mut stream, "*3\r\n$4\r\nHGET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await;
assert!(response.contains("value1"));
// Test HSET multiple fields
let response = send_command(&mut stream, "*6\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield2\r\n$6\r\nvalue2\r\n$6\r\nfield3\r\n$6\r\nvalue3\r\n").await;
assert!(response.contains("2")); // 2 new fields
// Test HGETALL
let response = send_command(&mut stream, "*2\r\n$7\r\nHGETALL\r\n$4\r\nhash\r\n").await;
assert!(response.contains("field1"));
assert!(response.contains("value1"));
assert!(response.contains("field2"));
assert!(response.contains("value2"));
// Test HLEN
let response = send_command(&mut stream, "*2\r\n$4\r\nHLEN\r\n$4\r\nhash\r\n").await;
assert!(response.contains("3"));
// Test HEXISTS
let response = send_command(&mut stream, "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await;
assert!(response.contains("1"));
let response = send_command(&mut stream, "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$7\r\nnoexist\r\n").await;
assert!(response.contains("0"));
// Test HDEL
let response = send_command(&mut stream, "*3\r\n$4\r\nHDEL\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await;
assert!(response.contains("1"));
// Test HKEYS
let response = send_command(&mut stream, "*2\r\n$5\r\nHKEYS\r\n$4\r\nhash\r\n").await;
assert!(response.contains("field2"));
assert!(response.contains("field3"));
assert!(!response.contains("field1")); // Should be deleted
// Test HVALS
let response = send_command(&mut stream, "*2\r\n$5\r\nHVALS\r\n$4\r\nhash\r\n").await;
assert!(response.contains("value2"));
assert!(response.contains("value3"));
}
#[tokio::test]
async fn test_expiration() {
let (mut server, port) = start_test_server("expiration").await;
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
// Test SETEX (expire in 1 second)
let response = send_command(&mut stream, "*5\r\n$3\r\nSET\r\n$6\r\nexpkey\r\n$5\r\nvalue\r\n$2\r\nEX\r\n$1\r\n1\r\n").await;
assert!(response.contains("OK"));
// Test TTL
let response = send_command(&mut stream, "*2\r\n$3\r\nTTL\r\n$6\r\nexpkey\r\n").await;
assert!(response.contains("1") || response.contains("0")); // Should be 1 or 0 seconds
// Test EXISTS
let response = send_command(&mut stream, "*2\r\n$6\r\nEXISTS\r\n$6\r\nexpkey\r\n").await;
assert!(response.contains("1"));
// Wait for expiration
sleep(Duration::from_millis(1100)).await;
// Test GET after expiration
let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$6\r\nexpkey\r\n").await;
assert!(response.contains("$-1")); // Should be NULL
// Test TTL after expiration
let response = send_command(&mut stream, "*2\r\n$3\r\nTTL\r\n$6\r\nexpkey\r\n").await;
assert!(response.contains("-2")); // Key doesn't exist
// Test EXISTS after expiration
let response = send_command(&mut stream, "*2\r\n$6\r\nEXISTS\r\n$6\r\nexpkey\r\n").await;
assert!(response.contains("0"));
}
#[tokio::test]
async fn test_scan_operations() {
let (mut server, port) = start_test_server("scan").await;
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
// Set up test data
for i in 0..5 {
let cmd = format!("*3\r\n$3\r\nSET\r\n$4\r\nkey{}\r\n$6\r\nvalue{}\r\n", i, i);
send_command(&mut stream, &cmd).await;
}
// Test SCAN
let response = send_command(&mut stream, "*6\r\n$4\r\nSCAN\r\n$1\r\n0\r\n$5\r\nMATCH\r\n$1\r\n*\r\n$5\r\nCOUNT\r\n$2\r\n10\r\n").await;
assert!(response.contains("key"));
// Test KEYS
let response = send_command(&mut stream, "*2\r\n$4\r\nKEYS\r\n$1\r\n*\r\n").await;
assert!(response.contains("key0"));
assert!(response.contains("key1"));
}
#[tokio::test]
async fn test_hscan_operations() {
let (mut server, port) = start_test_server("hscan").await;
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
// Set up hash data
for i in 0..3 {
let cmd = format!("*4\r\n$4\r\nHSET\r\n$8\r\ntesthash\r\n$6\r\nfield{}\r\n$6\r\nvalue{}\r\n", i, i);
send_command(&mut stream, &cmd).await;
}
// Test HSCAN
let response = send_command(&mut stream, "*7\r\n$5\r\nHSCAN\r\n$8\r\ntesthash\r\n$1\r\n0\r\n$5\r\nMATCH\r\n$1\r\n*\r\n$5\r\nCOUNT\r\n$2\r\n10\r\n").await;
assert!(response.contains("field"));
assert!(response.contains("value"));
}
#[tokio::test]
async fn test_transaction_operations() {
let (mut server, port) = start_test_server("transaction").await;
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
// Test MULTI
let response = send_command(&mut stream, "*1\r\n$5\r\nMULTI\r\n").await;
assert!(response.contains("OK"));
// Test queued commands
let response = send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$4\r\nkey1\r\n$6\r\nvalue1\r\n").await;
assert!(response.contains("QUEUED"));
let response = send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$4\r\nkey2\r\n$6\r\nvalue2\r\n").await;
assert!(response.contains("QUEUED"));
// Test EXEC
let response = send_command(&mut stream, "*1\r\n$4\r\nEXEC\r\n").await;
assert!(response.contains("OK")); // Should contain results of executed commands
// Verify commands were executed
let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$4\r\nkey1\r\n").await;
assert!(response.contains("value1"));
let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$4\r\nkey2\r\n").await;
assert!(response.contains("value2"));
}
#[tokio::test]
async fn test_discard_transaction() {
let (mut server, port) = start_test_server("discard").await;
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
// Test MULTI
let response = send_command(&mut stream, "*1\r\n$5\r\nMULTI\r\n").await;
assert!(response.contains("OK"));
// Test queued command
let response = send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$7\r\ndiscard\r\n$5\r\nvalue\r\n").await;
assert!(response.contains("QUEUED"));
// Test DISCARD
let response = send_command(&mut stream, "*1\r\n$7\r\nDISCARD\r\n").await;
assert!(response.contains("OK"));
// Verify command was not executed
let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$7\r\ndiscard\r\n").await;
assert!(response.contains("$-1")); // Should be NULL
}
#[tokio::test]
async fn test_type_command() {
let (mut server, port) = start_test_server("type").await;
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
// Test string type
send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$6\r\nstring\r\n$5\r\nvalue\r\n").await;
let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$6\r\nstring\r\n").await;
assert!(response.contains("string"));
// Test hash type
send_command(&mut stream, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$5\r\nfield\r\n$5\r\nvalue\r\n").await;
let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$4\r\nhash\r\n").await;
assert!(response.contains("hash"));
// Test non-existent key
let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$7\r\nnoexist\r\n").await;
assert!(response.contains("none"));
}
#[tokio::test]
async fn test_config_commands() {
let (mut server, port) = start_test_server("config").await;
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
// Test CONFIG GET databases
let response = send_command(&mut stream, "*3\r\n$6\r\nCONFIG\r\n$3\r\nGET\r\n$9\r\ndatabases\r\n").await;
assert!(response.contains("databases"));
assert!(response.contains("16"));
// Test CONFIG GET dir
let response = send_command(&mut stream, "*3\r\n$6\r\nCONFIG\r\n$3\r\nGET\r\n$3\r\ndir\r\n").await;
assert!(response.contains("dir"));
assert!(response.contains("/tmp/herodb_test_config"));
}
#[tokio::test]
async fn test_info_command() {
let (mut server, port) = start_test_server("info").await;
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
// Test INFO
let response = send_command(&mut stream, "*1\r\n$4\r\nINFO\r\n").await;
assert!(response.contains("redis_version"));
// Test INFO replication
let response = send_command(&mut stream, "*2\r\n$4\r\nINFO\r\n$11\r\nreplication\r\n").await;
assert!(response.contains("role:master"));
}
#[tokio::test]
async fn test_error_handling() {
let (mut server, port) = start_test_server("error").await;
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
// Test WRONGTYPE error - try to use hash command on string
send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$6\r\nstring\r\n$5\r\nvalue\r\n").await;
let response = send_command(&mut stream, "*3\r\n$4\r\nHGET\r\n$6\r\nstring\r\n$5\r\nfield\r\n").await;
assert!(response.contains("WRONGTYPE"));
// Test unknown command
let response = send_command(&mut stream, "*1\r\n$7\r\nUNKNOWN\r\n").await;
assert!(response.contains("unknown cmd") || response.contains("ERR"));
// Test EXEC without MULTI
let response = send_command(&mut stream, "*1\r\n$4\r\nEXEC\r\n").await;
assert!(response.contains("ERR"));
// Test DISCARD without MULTI
let response = send_command(&mut stream, "*1\r\n$7\r\nDISCARD\r\n").await;
assert!(response.contains("ERR"));
}
#[tokio::test]
async fn test_list_operations() {
let (mut server, port) = start_test_server("list").await;
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
// Test LPUSH
let response = send_command(&mut stream, "*4\r\n$5\r\nLPUSH\r\n$4\r\nlist\r\n$1\r\na\r\n$1\r\nb\r\n").await;
assert!(response.contains("2")); // 2 elements
// Test RPUSH
let response = send_command(&mut stream, "*4\r\n$5\r\nRPUSH\r\n$4\r\nlist\r\n$1\r\nc\r\n$1\r\nd\r\n").await;
assert!(response.contains("4")); // 4 elements
// Test LLEN
let response = send_command(&mut stream, "*2\r\n$4\r\nLLEN\r\n$4\r\nlist\r\n").await;
assert!(response.contains("4"));
// Test LRANGE
let response = send_command(&mut stream, "*4\r\n$6\r\nLRANGE\r\n$4\r\nlist\r\n$1\r\n0\r\n$2\r\n-1\r\n").await;
assert_eq!(response, "*4\r\n$1\r\nb\r\n$1\r\na\r\n$1\r\nc\r\n$1\r\nd\r\n");
// Test LINDEX
let response = send_command(&mut stream, "*3\r\n$6\r\nLINDEX\r\n$4\r\nlist\r\n$1\r\n0\r\n").await;
assert_eq!(response, "$1\r\nb\r\n");
// Test LPOP
let response = send_command(&mut stream, "*2\r\n$4\r\nLPOP\r\n$4\r\nlist\r\n").await;
assert_eq!(response, "$1\r\nb\r\n");
// Test RPOP
let response = send_command(&mut stream, "*2\r\n$4\r\nRPOP\r\n$4\r\nlist\r\n").await;
assert_eq!(response, "$1\r\nd\r\n");
// Test LREM
send_command(&mut stream, "*3\r\n$5\r\nLPUSH\r\n$4\r\nlist\r\n$1\r\na\r\n").await; // list is now a, c, a
let response = send_command(&mut stream, "*4\r\n$4\r\nLREM\r\n$4\r\nlist\r\n$1\r\n1\r\n$1\r\na\r\n").await;
assert!(response.contains("1"));
// Test LTRIM
let response = send_command(&mut stream, "*4\r\n$5\r\nLTRIM\r\n$4\r\nlist\r\n$1\r\n0\r\n$1\r\n0\r\n").await;
assert!(response.contains("OK"));
let response = send_command(&mut stream, "*2\r\n$4\r\nLLEN\r\n$4\r\nlist\r\n").await;
assert!(response.contains("1"));
}

View File

@@ -0,0 +1,212 @@
use herodb::{server::Server, options::DBOption};
use std::time::Duration;
use tokio::time::sleep;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
// Helper function to start a test server with clean data directory
async fn start_test_server(test_name: &str) -> (Server, u16) {
use std::sync::atomic::{AtomicU16, Ordering};
static PORT_COUNTER: AtomicU16 = AtomicU16::new(17000);
// Get a unique port for this test
let port = PORT_COUNTER.fetch_add(1, Ordering::SeqCst);
let test_dir = format!("/tmp/herodb_test_{}", test_name);
// Clean up any existing test data
let _ = std::fs::remove_dir_all(&test_dir);
std::fs::create_dir_all(&test_dir).unwrap();
let option = DBOption {
dir: test_dir,
port,
debug: true,
encrypt: false,
encryption_key: None,
};
let server = Server::new(option).await;
(server, port)
}
// Helper function to send Redis command and get response
async fn send_redis_command(port: u16, command: &str) -> String {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port)).await.unwrap();
stream.write_all(command.as_bytes()).await.unwrap();
let mut buffer = [0; 1024];
let n = stream.read(&mut buffer).await.unwrap();
String::from_utf8_lossy(&buffer[..n]).to_string()
}
#[tokio::test]
async fn test_basic_redis_functionality() {
let (mut server, port) = start_test_server("basic").await;
// Start server in background with timeout
let server_handle = tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
// Accept only a few connections for testing
for _ in 0..10 {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
// Test PING
let response = send_redis_command(port, "*1\r\n$4\r\nPING\r\n").await;
assert!(response.contains("PONG"));
// Test SET
let response = send_redis_command(port, "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n").await;
assert!(response.contains("OK"));
// Test GET
let response = send_redis_command(port, "*2\r\n$3\r\nGET\r\n$3\r\nkey\r\n").await;
assert!(response.contains("value"));
// Test HSET
let response = send_redis_command(port, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$5\r\nfield\r\n$5\r\nvalue\r\n").await;
assert!(response.contains("1"));
// Test HGET
let response = send_redis_command(port, "*3\r\n$4\r\nHGET\r\n$4\r\nhash\r\n$5\r\nfield\r\n").await;
assert!(response.contains("value"));
// Test EXISTS
let response = send_redis_command(port, "*2\r\n$6\r\nEXISTS\r\n$3\r\nkey\r\n").await;
assert!(response.contains("1"));
// Test TTL
let response = send_redis_command(port, "*2\r\n$3\r\nTTL\r\n$3\r\nkey\r\n").await;
assert!(response.contains("-1")); // No expiration
// Test TYPE
let response = send_redis_command(port, "*2\r\n$4\r\nTYPE\r\n$3\r\nkey\r\n").await;
assert!(response.contains("string"));
// Test QUIT to close connection gracefully
let response = send_redis_command(port, "*1\r\n$4\r\nQUIT\r\n").await;
assert!(response.contains("OK"));
// Stop the server
server_handle.abort();
println!("✅ All basic Redis functionality tests passed!");
}
#[tokio::test]
async fn test_hash_operations() {
let (mut server, port) = start_test_server("hash_ops").await;
// Start server in background with timeout
let server_handle = tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
// Accept only a few connections for testing
for _ in 0..5 {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
// Test HSET multiple fields
let response = send_redis_command(port, "*6\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n$6\r\nfield2\r\n$6\r\nvalue2\r\n").await;
assert!(response.contains("2")); // 2 new fields
// Test HGETALL
let response = send_redis_command(port, "*2\r\n$7\r\nHGETALL\r\n$4\r\nhash\r\n").await;
assert!(response.contains("field1"));
assert!(response.contains("value1"));
assert!(response.contains("field2"));
assert!(response.contains("value2"));
// Test HEXISTS
let response = send_redis_command(port, "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await;
assert!(response.contains("1"));
// Test HLEN
let response = send_redis_command(port, "*2\r\n$4\r\nHLEN\r\n$4\r\nhash\r\n").await;
assert!(response.contains("2"));
// Test HSCAN
let response = send_redis_command(port, "*7\r\n$5\r\nHSCAN\r\n$4\r\nhash\r\n$1\r\n0\r\n$5\r\nMATCH\r\n$1\r\n*\r\n$5\r\nCOUNT\r\n$2\r\n10\r\n").await;
assert!(response.contains("field1"));
assert!(response.contains("value1"));
assert!(response.contains("field2"));
assert!(response.contains("value2"));
// Stop the server
server_handle.abort();
println!("✅ All hash operations tests passed!");
}
#[tokio::test]
async fn test_transaction_operations() {
let (mut server, port) = start_test_server("transactions").await;
// Start server in background with timeout
let server_handle = tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
// Accept only a few connections for testing
for _ in 0..5 {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
// Use a single connection for the transaction
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port)).await.unwrap();
// Test MULTI
stream.write_all("*1\r\n$5\r\nMULTI\r\n".as_bytes()).await.unwrap();
let mut buffer = [0; 1024];
let n = stream.read(&mut buffer).await.unwrap();
let response = String::from_utf8_lossy(&buffer[..n]);
assert!(response.contains("OK"));
// Test queued commands
stream.write_all("*3\r\n$3\r\nSET\r\n$4\r\nkey1\r\n$6\r\nvalue1\r\n".as_bytes()).await.unwrap();
let n = stream.read(&mut buffer).await.unwrap();
let response = String::from_utf8_lossy(&buffer[..n]);
assert!(response.contains("QUEUED"));
stream.write_all("*3\r\n$3\r\nSET\r\n$4\r\nkey2\r\n$6\r\nvalue2\r\n".as_bytes()).await.unwrap();
let n = stream.read(&mut buffer).await.unwrap();
let response = String::from_utf8_lossy(&buffer[..n]);
assert!(response.contains("QUEUED"));
// Test EXEC
stream.write_all("*1\r\n$4\r\nEXEC\r\n".as_bytes()).await.unwrap();
let n = stream.read(&mut buffer).await.unwrap();
let response = String::from_utf8_lossy(&buffer[..n]);
assert!(response.contains("OK")); // Should contain array of OK responses
// Verify commands were executed
let response = send_redis_command(port, "*2\r\n$3\r\nGET\r\n$4\r\nkey1\r\n").await;
assert!(response.contains("value1"));
// Stop the server
server_handle.abort();
println!("✅ All transaction operations tests passed!");
}

View File

@@ -0,0 +1,183 @@
use herodb::{server::Server, options::DBOption};
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::time::sleep;
// Helper function to start a test server with clean data directory
async fn start_test_server(test_name: &str) -> (Server, u16) {
use std::sync::atomic::{AtomicU16, Ordering};
static PORT_COUNTER: AtomicU16 = AtomicU16::new(16500);
let port = PORT_COUNTER.fetch_add(1, Ordering::SeqCst);
let test_dir = format!("/tmp/herodb_simple_test_{}", test_name);
// Clean up any existing test data
let _ = std::fs::remove_dir_all(&test_dir);
std::fs::create_dir_all(&test_dir).unwrap();
let option = DBOption {
dir: test_dir,
port,
debug: false,
encrypt: false,
encryption_key: None,
};
let server = Server::new(option).await;
(server, port)
}
// Helper function to send command and get response
async fn send_command(stream: &mut TcpStream, command: &str) -> String {
stream.write_all(command.as_bytes()).await.unwrap();
let mut buffer = [0; 1024];
let n = stream.read(&mut buffer).await.unwrap();
String::from_utf8_lossy(&buffer[..n]).to_string()
}
// Helper function to connect to the test server
async fn connect_to_server(port: u16) -> TcpStream {
let mut attempts = 0;
loop {
match TcpStream::connect(format!("127.0.0.1:{}", port)).await {
Ok(stream) => return stream,
Err(_) if attempts < 10 => {
attempts += 1;
sleep(Duration::from_millis(100)).await;
}
Err(e) => panic!("Failed to connect to test server: {}", e),
}
}
}
#[tokio::test]
async fn test_basic_ping_simple() {
let (mut server, port) = start_test_server("ping").await;
// Start server in background
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(200)).await;
let mut stream = connect_to_server(port).await;
let response = send_command(&mut stream, "*1\r\n$4\r\nPING\r\n").await;
assert!(response.contains("PONG"));
}
#[tokio::test]
async fn test_hset_clean_db() {
let (mut server, port) = start_test_server("hset_clean").await;
// Start server in background
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(200)).await;
let mut stream = connect_to_server(port).await;
// Test HSET - should return 1 for new field
let response = send_command(&mut stream, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n").await;
println!("HSET response: {}", response);
assert!(response.contains("1"), "Expected HSET to return 1, got: {}", response);
// Test HGET
let response = send_command(&mut stream, "*3\r\n$4\r\nHGET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await;
println!("HGET response: {}", response);
assert!(response.contains("value1"));
}
#[tokio::test]
async fn test_type_command_simple() {
let (mut server, port) = start_test_server("type").await;
// Start server in background
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(200)).await;
let mut stream = connect_to_server(port).await;
// Test string type
send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$6\r\nstring\r\n$5\r\nvalue\r\n").await;
let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$6\r\nstring\r\n").await;
println!("TYPE string response: {}", response);
assert!(response.contains("string"));
// Test hash type
send_command(&mut stream, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$5\r\nfield\r\n$5\r\nvalue\r\n").await;
let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$4\r\nhash\r\n").await;
println!("TYPE hash response: {}", response);
assert!(response.contains("hash"));
// Test non-existent key
let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$7\r\nnoexist\r\n").await;
println!("TYPE noexist response: {}", response);
assert!(response.contains("none"), "Expected 'none' for non-existent key, got: {}", response);
}
#[tokio::test]
async fn test_hexists_simple() {
let (mut server, port) = start_test_server("hexists").await;
// Start server in background
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(200)).await;
let mut stream = connect_to_server(port).await;
// Set up hash
send_command(&mut stream, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n").await;
// Test HEXISTS for existing field
let response = send_command(&mut stream, "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await;
println!("HEXISTS existing field response: {}", response);
assert!(response.contains("1"));
// Test HEXISTS for non-existent field
let response = send_command(&mut stream, "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$7\r\nnoexist\r\n").await;
println!("HEXISTS non-existent field response: {}", response);
assert!(response.contains("0"), "Expected HEXISTS to return 0 for non-existent field, got: {}", response);
}

View File

@@ -0,0 +1,10 @@
[package]
name = "libcrypto"
version = "0.1.0"
edition = "2021"
[dependencies]
chacha20poly1305 = { workspace = true }
rand = { workspace = true }
sha2 = { workspace = true }
thiserror = { workspace = true }

View File

@@ -0,0 +1,72 @@
// In crates/libcrypto/src/lib.rs
use chacha20poly1305::{
aead::{Aead, KeyInit, OsRng},
XChaCha20Poly1305, XNonce,
};
use rand::RngCore;
use sha2::{Digest, Sha256};
use thiserror::Error;
const VERSION: u8 = 1;
const NONCE_LEN: usize = 24;
const TAG_LEN: usize = 16;
#[derive(Error, Debug)]
pub enum CryptoError {
#[error("invalid format: data too short")]
Format,
#[error("unknown version: {0}")]
Version(u8),
#[error("decryption failed: wrong key or corrupted data")]
Decrypt,
}
/// Super-simple factory: new(secret) + encrypt(bytes) + decrypt(bytes)
pub struct CryptoFactory {
key: chacha20poly1305::Key,
}
impl CryptoFactory {
/// Accepts any secret bytes; turns them into a 32-byte key (SHA-256).
pub fn new<S: AsRef<[u8]>>(secret: S) -> Self {
let mut h = Sha256::new();
h.update(b"xchacha20poly1305-factory:v1"); // domain separation
h.update(secret.as_ref());
let digest = h.finalize(); // 32 bytes
let key = chacha20poly1305::Key::from_slice(&digest).to_owned();
Self { key }
}
/// Output layout: [version:1][nonce:24][ciphertext||tag]
pub fn encrypt(&self, plaintext: &[u8]) -> Vec<u8> {
let cipher = XChaCha20Poly1305::new(&self.key);
let mut nonce_bytes = [0u8; NONCE_LEN];
OsRng.fill_bytes(&mut nonce_bytes);
let nonce = XNonce::from_slice(&nonce_bytes);
let mut out = Vec::with_capacity(1 + NONCE_LEN + plaintext.len() + TAG_LEN);
out.push(VERSION);
out.extend_from_slice(&nonce_bytes);
let ct = cipher.encrypt(nonce, plaintext).expect("encrypt");
out.extend_from_slice(&ct);
out
}
pub fn decrypt(&self, blob: &[u8]) -> Result<Vec<u8>, CryptoError> {
if blob.len() < 1 + NONCE_LEN + TAG_LEN {
return Err(CryptoError::Format);
}
let ver = blob[0];
if ver != VERSION {
return Err(CryptoError::Version(ver));
}
let nonce = XNonce::from_slice(&blob[1..1 + NONCE_LEN]);
let ct = &blob[1 + NONCE_LEN..];
let cipher = XChaCha20Poly1305::new(&self.key);
cipher.decrypt(nonce, ct).map_err(|_| CryptoError::Decrypt)
}
}

View File

@@ -0,0 +1,12 @@
[package]
name = "libcryptoa"
version = "0.1.0"
edition = "2021"
[dependencies]
age = { workspace = true }
secrecy = { workspace = true }
ed25519-dalek = { workspace = true }
base64 = { workspace = true }
rand = { workspace = true }
thiserror = { workspace = true }

View File

@@ -0,0 +1,100 @@
// In crates/libcryptoa/src/lib.rs
use std::str::FromStr;
use age::{Decryptor, Encryptor, x25519};
use base64::{engine::general_purpose::STANDARD as B64, Engine as _};
use ed25519_dalek::{Signature, Signer, SigningKey, Verifier, VerifyingKey};
use secrecy::ExposeSecret;
use thiserror::Error;
#[derive(Error, Debug)]
pub enum AsymmetricCryptoError {
#[error("key parsing failed")]
ParseKey,
#[error("age crypto error: {0}")]
Age(String),
#[error("invalid utf-8 in plaintext")]
Utf8,
#[error("invalid signature length")]
SignatureLen,
#[error("signature verification failed")]
Verify,
#[error("base64 decoding failed: {0}")]
Base64(#[from] base64::DecodeError),
#[error("io error: {0}")]
Io(#[from] std::io::Error),
}
fn parse_recipient(s: &str) -> Result<x25519::Recipient, AsymmetricCryptoError> {
x25519::Recipient::from_str(s).map_err(|_| AsymmetricCryptoError::ParseKey)
}
fn parse_identity(s: &str) -> Result<x25519::Identity, AsymmetricCryptoError> {
x25519::Identity::from_str(s).map_err(|_| AsymmetricCryptoError::ParseKey)
}
fn parse_ed25519_signing_key(s: &str) -> Result<SigningKey, AsymmetricCryptoError> {
let bytes = B64.decode(s)?;
let key_bytes: [u8; 32] = bytes.try_into().map_err(|_| AsymmetricCryptoError::ParseKey)?;
Ok(SigningKey::from_bytes(&key_bytes))
}
fn parse_ed25519_verifying_key(s: &str) -> Result<VerifyingKey, AsymmetricCryptoError> {
let bytes = B64.decode(s)?;
let key_bytes: [u8; 32] = bytes.try_into().map_err(|_| AsymmetricCryptoError::ParseKey)?;
VerifyingKey::from_bytes(&key_bytes).map_err(|_| AsymmetricCryptoError::ParseKey)
}
pub fn gen_enc_keypair() -> (String, String) {
let id = x25519::Identity::generate();
let pk = id.to_public();
(pk.to_string(), id.to_string().expose_secret().to_string())
}
pub fn gen_sign_keypair() -> (String, String) {
let signing_key = SigningKey::generate(&mut rand::rngs::OsRng);
let verifying_key = signing_key.verifying_key();
(B64.encode(verifying_key.to_bytes()), B64.encode(signing_key.to_bytes()))
}
pub fn encrypt_b64(recipient_str: &str, msg: &str) -> Result<String, AsymmetricCryptoError> {
let recipient = parse_recipient(recipient_str)?;
let encryptor = Encryptor::with_recipients(vec![Box::new(recipient)])
.ok_or_else(|| AsymmetricCryptoError::Age("Failed to create encryptor".into()))?;
let mut encrypted = vec![];
let mut writer = encryptor.wrap_output(&mut encrypted)?;
std::io::Write::write_all(&mut writer, msg.as_bytes())?;
writer.finish()?;
Ok(B64.encode(encrypted))
}
pub fn decrypt_b64(identity_str: &str, ct_b64: &str) -> Result<String, AsymmetricCryptoError> {
let identity = parse_identity(identity_str)?;
let ct = B64.decode(ct_b64)?;
let decryptor = Decryptor::new(&ct[..]).map_err(|e| AsymmetricCryptoError::Age(e.to_string()))?;
let mut decrypted = vec![];
if let Decryptor::Recipients(d) = decryptor {
let mut reader = d.decrypt(std::iter::once(&identity as &dyn age::Identity))
.map_err(|e| AsymmetricCryptoError::Age(e.to_string()))?;
std::io::Read::read_to_end(&mut reader, &mut decrypted)?;
String::from_utf8(decrypted).map_err(|_| AsymmetricCryptoError::Utf8)
} else {
Err(AsymmetricCryptoError::Age("Passphrase decryption not supported".into()))
}
}
pub fn sign_b64(signing_secret_str: &str, msg: &str) -> Result<String, AsymmetricCryptoError> {
let signing_key = parse_ed25519_signing_key(signing_secret_str)?;
let signature = signing_key.sign(msg.as_bytes());
Ok(B64.encode(signature.to_bytes()))
}
pub fn verify_b64(verify_pub_str: &str, msg: &str, sig_b64: &str) -> Result<bool, AsymmetricCryptoError> {
let verifying_key = parse_ed25519_verifying_key(verify_pub_str)?;
let sig_bytes = B64.decode(sig_b64)?;
let signature = Signature::from_slice(&sig_bytes).map_err(|_| AsymmetricCryptoError::SignatureLen)?;
Ok(verifying_key.verify(msg.as_bytes(), &signature).is_ok())
}

View File

@@ -0,0 +1,15 @@
[package]
name = "libdbstorage"
version = "0.1.0"
edition = "2021"
[dependencies]
redb = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
thiserror = { workspace = true }
# Local Crate Dependencies
libcrypto = { path = "../libcrypto" }
tokio = { version = "1", features = ["full"] }
bincode = "1.3.3"

View File

@@ -4,7 +4,6 @@ use tokio::sync::mpsc;
use redb; use redb;
use bincode; use bincode;
use crate::protocol::Protocol;
// todo: more error types // todo: more error types
#[derive(Debug)] #[derive(Debug)]
@@ -81,3 +80,10 @@ impl From<tokio::sync::mpsc::error::SendError<()>> for DBError {
DBError(item.to_string().clone()) DBError(item.to_string().clone())
} }
} }
impl From<serde_json::Error> for DBError {
fn from(item: serde_json::Error) -> Self {
DBError(item.to_string())
}
}

View File

@@ -0,0 +1,119 @@
// In crates/libdbstorage/src/lib.rs
use std::{
path::Path,
time::{SystemTime, UNIX_EPOCH},
};
use libcrypto::CryptoFactory; // Correct import
use redb::{Database, TableDefinition};
use serde::{Deserialize, Serialize};
pub mod error; // Declare the error module
pub use error::DBError; // Re-export for users of this crate
// Declare storage module
pub mod storage;
// Table definitions for different Redis data types
const TYPES_TABLE: TableDefinition<&str, &str> = TableDefinition::new("types");
const STRINGS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("strings");
const HASHES_TABLE: TableDefinition<(&str, &str), &[u8]> = TableDefinition::new("hashes");
const LISTS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("lists");
const STREAMS_META_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("streams_meta");
const STREAMS_DATA_TABLE: TableDefinition<(&str, &str), &[u8]> = TableDefinition::new("streams_data");
const ENCRYPTED_TABLE: TableDefinition<&str, u8> = TableDefinition::new("encrypted");
const EXPIRATION_TABLE: TableDefinition<&str, u64> = TableDefinition::new("expiration");
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct StreamEntry {
pub fields: Vec<(String, String)>,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ListValue {
pub elements: Vec<String>,
}
#[inline]
pub fn now_in_millis() -> u128 {
let start = SystemTime::now();
let duration_since_epoch = start.duration_since(UNIX_EPOCH).unwrap();
duration_since_epoch.as_millis()
}
pub struct Storage {
db: Database,
crypto: Option<CryptoFactory>,
}
impl Storage {
pub fn new(path: impl AsRef<Path>, should_encrypt: bool, master_key: Option<&str>) -> Result<Self, DBError> {
let db = Database::create(path)?;
// Create tables if they don't exist
let write_txn = db.begin_write()?;
{
let _ = write_txn.open_table(TYPES_TABLE)?;
let _ = write_txn.open_table(STRINGS_TABLE)?;
let _ = write_txn.open_table(HASHES_TABLE)?;
let _ = write_txn.open_table(LISTS_TABLE)?;
let _ = write_txn.open_table(STREAMS_META_TABLE)?;
let _ = write_txn.open_table(STREAMS_DATA_TABLE)?;
let _ = write_txn.open_table(ENCRYPTED_TABLE)?;
let _ = write_txn.open_table(EXPIRATION_TABLE)?;
}
write_txn.commit()?;
// Check if database was previously encrypted
let read_txn = db.begin_read()?;
let encrypted_table = read_txn.open_table(ENCRYPTED_TABLE)?;
let was_encrypted = encrypted_table.get("encrypted")?.map(|v| v.value() == 1).unwrap_or(false);
drop(read_txn);
let crypto = if should_encrypt || was_encrypted {
if let Some(key) = master_key {
Some(CryptoFactory::new(key.as_bytes()))
} else {
return Err(DBError("Encryption requested but no master key provided".to_string()));
}
} else {
None
};
// If we're enabling encryption for the first time, mark it
if should_encrypt && !was_encrypted {
let write_txn = db.begin_write()?;
{
let mut encrypted_table = write_txn.open_table(ENCRYPTED_TABLE)?;
encrypted_table.insert("encrypted", &1u8)?;
}
write_txn.commit()?;
}
Ok(Storage {
db,
crypto,
})
}
pub fn is_encrypted(&self) -> bool {
self.crypto.is_some()
}
// Helper methods for encryption
fn encrypt_if_needed(&self, data: &[u8]) -> Result<Vec<u8>, DBError> {
if let Some(crypto) = &self.crypto {
Ok(crypto.encrypt(data))
} else {
Ok(data.to_vec())
}
}
fn decrypt_if_needed(&self, data: &[u8]) -> Result<Vec<u8>, DBError> {
if let Some(crypto) = &self.crypto {
Ok(crypto.decrypt(data).map_err(|e| DBError(e.to_string()))?)
} else {
Ok(data.to_vec())
}
}
}

View File

@@ -0,0 +1,4 @@
pub mod storage_basic;
pub mod storage_hset;
pub mod storage_lists;
pub mod storage_extra;

View File

@@ -0,0 +1,218 @@
use redb::{ReadableTable};
use crate::error::DBError;
use crate::{Storage, TYPES_TABLE, STRINGS_TABLE, HASHES_TABLE, LISTS_TABLE, STREAMS_META_TABLE, STREAMS_DATA_TABLE, EXPIRATION_TABLE, now_in_millis};
impl Storage {
pub fn flushdb(&self) -> Result<(), DBError> {
let write_txn = self.db.begin_write()?;
{
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
let mut strings_table = write_txn.open_table(STRINGS_TABLE)?;
let mut hashes_table = write_txn.open_table(HASHES_TABLE)?;
let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
let mut streams_meta_table = write_txn.open_table(STREAMS_META_TABLE)?;
let mut streams_data_table = write_txn.open_table(STREAMS_DATA_TABLE)?;
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
// inefficient, but there is no other way
let keys: Vec<String> = types_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect();
for key in keys {
types_table.remove(key.as_str())?;
}
let keys: Vec<String> = strings_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect();
for key in keys {
strings_table.remove(key.as_str())?;
}
let keys: Vec<(String, String)> = hashes_table
.iter()?
.map(|item| {
let binding = item.unwrap();
let (k, f) = binding.0.value();
(k.to_string(), f.to_string())
})
.collect();
for (key, field) in keys {
hashes_table.remove((key.as_str(), field.as_str()))?;
}
let keys: Vec<String> = lists_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect();
for key in keys {
lists_table.remove(key.as_str())?;
}
let keys: Vec<String> = streams_meta_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect();
for key in keys {
streams_meta_table.remove(key.as_str())?;
}
let keys: Vec<(String,String)> = streams_data_table.iter()?.map(|item| {
let binding = item.unwrap();
let (key, field) = binding.0.value();
(key.to_string(), field.to_string())
}).collect();
for (key, field) in keys {
streams_data_table.remove((key.as_str(), field.as_str()))?;
}
let keys: Vec<String> = expiration_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect();
for key in keys {
expiration_table.remove(key.as_str())?;
}
}
write_txn.commit()?;
Ok(())
}
pub fn get_key_type(&self, key: &str) -> Result<Option<String>, DBError> {
let read_txn = self.db.begin_read()?;
let table = read_txn.open_table(TYPES_TABLE)?;
// Before returning type, check for expiration
if let Some(type_val) = table.get(key)? {
if type_val.value() == "string" {
let expiration_table = read_txn.open_table(EXPIRATION_TABLE)?;
if let Some(expires_at) = expiration_table.get(key)? {
if now_in_millis() > expires_at.value() as u128 {
// The key is expired, so it effectively has no type
return Ok(None);
}
}
}
Ok(Some(type_val.value().to_string()))
} else {
Ok(None)
}
}
// ✅ ENCRYPTION APPLIED: Value is encrypted/decrypted
pub fn get(&self, key: &str) -> Result<Option<String>, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "string" => {
// Check expiration first (unencrypted)
let expiration_table = read_txn.open_table(EXPIRATION_TABLE)?;
if let Some(expires_at) = expiration_table.get(key)? {
if now_in_millis() > expires_at.value() as u128 {
drop(read_txn);
self.del(key.to_string())?;
return Ok(None);
}
}
// Get and decrypt value
let strings_table = read_txn.open_table(STRINGS_TABLE)?;
match strings_table.get(key)? {
Some(data) => {
let decrypted = self.decrypt_if_needed(data.value())?;
let value = String::from_utf8(decrypted)?;
Ok(Some(value))
}
None => Ok(None),
}
}
_ => Ok(None),
}
}
// ✅ ENCRYPTION APPLIED: Value is encrypted before storage
pub fn set(&self, key: String, value: String) -> Result<(), DBError> {
let write_txn = self.db.begin_write()?;
{
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
types_table.insert(key.as_str(), "string")?;
let mut strings_table = write_txn.open_table(STRINGS_TABLE)?;
// Only encrypt the value, not expiration
let encrypted = self.encrypt_if_needed(value.as_bytes())?;
strings_table.insert(key.as_str(), encrypted.as_slice())?;
// Remove any existing expiration since this is a regular SET
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
expiration_table.remove(key.as_str())?;
}
write_txn.commit()?;
Ok(())
}
// ✅ ENCRYPTION APPLIED: Value is encrypted before storage
pub fn setx(&self, key: String, value: String, expire_ms: u128) -> Result<(), DBError> {
let write_txn = self.db.begin_write()?;
{
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
types_table.insert(key.as_str(), "string")?;
let mut strings_table = write_txn.open_table(STRINGS_TABLE)?;
// Only encrypt the value
let encrypted = self.encrypt_if_needed(value.as_bytes())?;
strings_table.insert(key.as_str(), encrypted.as_slice())?;
// Store expiration separately (unencrypted)
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
let expires_at = expire_ms + now_in_millis();
expiration_table.insert(key.as_str(), &(expires_at as u64))?;
}
write_txn.commit()?;
Ok(())
}
pub fn del(&self, key: String) -> Result<(), DBError> {
let write_txn = self.db.begin_write()?;
{
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
let mut strings_table = write_txn.open_table(STRINGS_TABLE)?;
let mut hashes_table: redb::Table<(&str, &str), &[u8]> = write_txn.open_table(HASHES_TABLE)?;
let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
// Remove from type table
types_table.remove(key.as_str())?;
// Remove from strings table
strings_table.remove(key.as_str())?;
// Remove all hash fields for this key
let mut to_remove = Vec::new();
let mut iter = hashes_table.iter()?;
while let Some(entry) = iter.next() {
let entry = entry?;
let (hash_key, field) = entry.0.value();
if hash_key == key.as_str() {
to_remove.push((hash_key.to_string(), field.to_string()));
}
}
drop(iter);
for (hash_key, field) in to_remove {
hashes_table.remove((hash_key.as_str(), field.as_str()))?;
}
// Remove from lists table
lists_table.remove(key.as_str())?;
// Also remove expiration
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
expiration_table.remove(key.as_str())?;
}
write_txn.commit()?;
Ok(())
}
pub fn keys(&self, pattern: &str) -> Result<Vec<String>, DBError> {
let read_txn = self.db.begin_read()?;
let table = read_txn.open_table(TYPES_TABLE)?;
let mut keys = Vec::new();
let mut iter = table.iter()?;
while let Some(entry) = iter.next() {
let key = entry?.0.value().to_string();
if pattern == "*" || super::storage_extra::glob_match(pattern, &key) {
keys.push(key);
}
}
Ok(keys)
}
}

View File

@@ -0,0 +1,168 @@
use redb::{ReadableTable};
use crate::error::DBError;
use crate::{Storage, TYPES_TABLE, STRINGS_TABLE, EXPIRATION_TABLE, now_in_millis};
impl Storage {
// ✅ ENCRYPTION APPLIED: Values are decrypted after retrieval
pub fn scan(&self, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<(String, String)>), DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
let strings_table = read_txn.open_table(STRINGS_TABLE)?;
let mut result = Vec::new();
let mut current_cursor = 0u64;
let limit = count.unwrap_or(10) as usize;
let mut iter = types_table.iter()?;
while let Some(entry) = iter.next() {
let entry = entry?;
let key = entry.0.value().to_string();
let key_type = entry.1.value().to_string();
if current_cursor >= cursor {
// Apply pattern matching if specified
let matches = if let Some(pat) = pattern {
glob_match(pat, &key)
} else {
true
};
if matches {
// For scan, we return key-value pairs for string types
if key_type == "string" {
if let Some(data) = strings_table.get(key.as_str())? {
let decrypted = self.decrypt_if_needed(data.value())?;
let value = String::from_utf8(decrypted)?;
result.push((key, value));
} else {
result.push((key, String::new()));
}
} else {
// For non-string types, just return the key with type as value
result.push((key, key_type));
}
if result.len() >= limit {
break;
}
}
}
current_cursor += 1;
}
let next_cursor = if result.len() < limit { 0 } else { current_cursor };
Ok((next_cursor, result))
}
pub fn ttl(&self, key: &str) -> Result<i64, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "string" => {
let expiration_table = read_txn.open_table(EXPIRATION_TABLE)?;
match expiration_table.get(key)? {
Some(expires_at) => {
let now = now_in_millis();
let expires_at_ms = expires_at.value() as u128;
if now >= expires_at_ms {
Ok(-2) // Key has expired
} else {
Ok(((expires_at_ms - now) / 1000) as i64) // TTL in seconds
}
}
None => Ok(-1), // Key exists but has no expiration
}
}
Some(_) => Ok(-1), // Key exists but is not a string (no expiration support for other types)
None => Ok(-2), // Key does not exist
}
}
pub fn exists(&self, key: &str) -> Result<bool, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "string" => {
// Check if string key has expired
let expiration_table = read_txn.open_table(EXPIRATION_TABLE)?;
if let Some(expires_at) = expiration_table.get(key)? {
if now_in_millis() > expires_at.value() as u128 {
return Ok(false); // Key has expired
}
}
Ok(true)
}
Some(_) => Ok(true), // Key exists and is not a string
None => Ok(false), // Key does not exist
}
}
}
// Utility function for glob pattern matching
pub fn glob_match(pattern: &str, text: &str) -> bool {
if pattern == "*" {
return true;
}
// Simple glob matching - supports * and ? wildcards
let pattern_chars: Vec<char> = pattern.chars().collect();
let text_chars: Vec<char> = text.chars().collect();
fn match_recursive(pattern: &[char], text: &[char], pi: usize, ti: usize) -> bool {
if pi >= pattern.len() {
return ti >= text.len();
}
if ti >= text.len() {
// Check if remaining pattern is all '*'
return pattern[pi..].iter().all(|&c| c == '*');
}
match pattern[pi] {
'*' => {
// Try matching zero or more characters
for i in ti..=text.len() {
if match_recursive(pattern, text, pi + 1, i) {
return true;
}
}
false
}
'?' => {
// Match exactly one character
match_recursive(pattern, text, pi + 1, ti + 1)
}
c => {
// Match exact character
if text[ti] == c {
match_recursive(pattern, text, pi + 1, ti + 1)
} else {
false
}
}
}
}
match_recursive(&pattern_chars, &text_chars, 0, 0)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_glob_match() {
assert!(glob_match("*", "anything"));
assert!(glob_match("hello", "hello"));
assert!(!glob_match("hello", "world"));
assert!(glob_match("h*o", "hello"));
assert!(glob_match("h*o", "ho"));
assert!(!glob_match("h*o", "hi"));
assert!(glob_match("h?llo", "hello"));
assert!(!glob_match("h?llo", "hllo"));
assert!(glob_match("*test*", "this_is_a_test_string"));
assert!(!glob_match("*test*", "this_is_a_string"));
}
}

View File

@@ -0,0 +1,318 @@
use redb::{ReadableTable};
use crate::error::DBError;
use crate::{Storage, TYPES_TABLE, HASHES_TABLE};
impl Storage {
// ✅ ENCRYPTION APPLIED: Values are encrypted before storage
pub fn hset(&self, key: &str, pairs: Vec<(String, String)>) -> Result<i64, DBError> {
let write_txn = self.db.begin_write()?;
let mut new_fields = 0i64;
{
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
let mut hashes_table = write_txn.open_table(HASHES_TABLE)?;
// Set the type to hash
types_table.insert(key, "hash")?;
for (field, value) in pairs {
// Check if field already exists
let exists = hashes_table.get((key, field.as_str()))?.is_some();
// Encrypt the value before storing
let encrypted = self.encrypt_if_needed(value.as_bytes())?;
hashes_table.insert((key, field.as_str()), encrypted.as_slice())?;
if !exists {
new_fields += 1;
}
}
}
write_txn.commit()?;
Ok(new_fields)
}
// ✅ ENCRYPTION APPLIED: Value is decrypted after retrieval
pub fn hget(&self, key: &str, field: &str) -> Result<Option<String>, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "hash" => {
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
match hashes_table.get((key, field))? {
Some(data) => {
let decrypted = self.decrypt_if_needed(data.value())?;
let value = String::from_utf8(decrypted)?;
Ok(Some(value))
}
None => Ok(None),
}
}
_ => Ok(None),
}
}
// ✅ ENCRYPTION APPLIED: All values are decrypted after retrieval
pub fn hgetall(&self, key: &str) -> Result<Vec<(String, String)>, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "hash" => {
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
let mut result = Vec::new();
let mut iter = hashes_table.iter()?;
while let Some(entry) = iter.next() {
let entry = entry?;
let (hash_key, field) = entry.0.value();
if hash_key == key {
let decrypted = self.decrypt_if_needed(entry.1.value())?;
let value = String::from_utf8(decrypted)?;
result.push((field.to_string(), value));
}
}
Ok(result)
}
_ => Ok(Vec::new()),
}
}
pub fn hdel(&self, key: &str, fields: Vec<String>) -> Result<i64, DBError> {
let write_txn = self.db.begin_write()?;
let mut deleted = 0i64;
// First check if key exists and is a hash
let is_hash = {
let types_table = write_txn.open_table(TYPES_TABLE)?;
let result = match types_table.get(key)? {
Some(type_val) => type_val.value() == "hash",
None => false,
};
result
};
if is_hash {
let mut hashes_table = write_txn.open_table(HASHES_TABLE)?;
for field in fields {
if hashes_table.remove((key, field.as_str()))?.is_some() {
deleted += 1;
}
}
// Check if hash is now empty and remove type if so
let mut has_fields = false;
let mut iter = hashes_table.iter()?;
while let Some(entry) = iter.next() {
let entry = entry?;
let (hash_key, _) = entry.0.value();
if hash_key == key {
has_fields = true;
break;
}
}
drop(iter);
if !has_fields {
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
types_table.remove(key)?;
}
}
write_txn.commit()?;
Ok(deleted)
}
pub fn hexists(&self, key: &str, field: &str) -> Result<bool, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "hash" => {
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
Ok(hashes_table.get((key, field))?.is_some())
}
_ => Ok(false),
}
}
pub fn hkeys(&self, key: &str) -> Result<Vec<String>, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "hash" => {
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
let mut result = Vec::new();
let mut iter = hashes_table.iter()?;
while let Some(entry) = iter.next() {
let entry = entry?;
let (hash_key, field) = entry.0.value();
if hash_key == key {
result.push(field.to_string());
}
}
Ok(result)
}
_ => Ok(Vec::new()),
}
}
// ✅ ENCRYPTION APPLIED: All values are decrypted after retrieval
pub fn hvals(&self, key: &str) -> Result<Vec<String>, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "hash" => {
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
let mut result = Vec::new();
let mut iter = hashes_table.iter()?;
while let Some(entry) = iter.next() {
let entry = entry?;
let (hash_key, _) = entry.0.value();
if hash_key == key {
let decrypted = self.decrypt_if_needed(entry.1.value())?;
let value = String::from_utf8(decrypted)?;
result.push(value);
}
}
Ok(result)
}
_ => Ok(Vec::new()),
}
}
pub fn hlen(&self, key: &str) -> Result<i64, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "hash" => {
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
let mut count = 0i64;
let mut iter = hashes_table.iter()?;
while let Some(entry) = iter.next() {
let entry = entry?;
let (hash_key, _) = entry.0.value();
if hash_key == key {
count += 1;
}
}
Ok(count)
}
_ => Ok(0),
}
}
// ✅ ENCRYPTION APPLIED: Values are decrypted after retrieval
pub fn hmget(&self, key: &str, fields: Vec<String>) -> Result<Vec<Option<String>>, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "hash" => {
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
let mut result = Vec::new();
for field in fields {
match hashes_table.get((key, field.as_str()))? {
Some(data) => {
let decrypted = self.decrypt_if_needed(data.value())?;
let value = String::from_utf8(decrypted)?;
result.push(Some(value));
}
None => result.push(None),
}
}
Ok(result)
}
_ => Ok(fields.into_iter().map(|_| None).collect()),
}
}
// ✅ ENCRYPTION APPLIED: Value is encrypted before storage
pub fn hsetnx(&self, key: &str, field: &str, value: &str) -> Result<bool, DBError> {
let write_txn = self.db.begin_write()?;
let mut result = false;
{
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
let mut hashes_table = write_txn.open_table(HASHES_TABLE)?;
// Check if field already exists
if hashes_table.get((key, field))?.is_none() {
// Set the type to hash
types_table.insert(key, "hash")?;
// Encrypt the value before storing
let encrypted = self.encrypt_if_needed(value.as_bytes())?;
hashes_table.insert((key, field), encrypted.as_slice())?;
result = true;
}
}
write_txn.commit()?;
Ok(result)
}
// ✅ ENCRYPTION APPLIED: Values are decrypted after retrieval
pub fn hscan(&self, key: &str, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<(String, String)>), DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "hash" => {
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
let mut result = Vec::new();
let mut current_cursor = 0u64;
let limit = count.unwrap_or(10) as usize;
let mut iter = hashes_table.iter()?;
while let Some(entry) = iter.next() {
let entry = entry?;
let (hash_key, field) = entry.0.value();
if hash_key == key {
if current_cursor >= cursor {
let field_str = field.to_string();
// Apply pattern matching if specified
let matches = if let Some(pat) = pattern {
super::storage_extra::glob_match(pat, &field_str)
} else {
true
};
if matches {
let decrypted = self.decrypt_if_needed(entry.1.value())?;
let value = String::from_utf8(decrypted)?;
result.push((field_str, value));
if result.len() >= limit {
break;
}
}
}
current_cursor += 1;
}
}
let next_cursor = if result.len() < limit { 0 } else { current_cursor };
Ok((next_cursor, result))
}
_ => Ok((0, Vec::new())),
}
}
}

View File

@@ -0,0 +1,403 @@
use redb::{ReadableTable};
use crate::error::DBError;
use crate::{Storage, TYPES_TABLE, LISTS_TABLE};
impl Storage {
// ✅ ENCRYPTION APPLIED: Elements are encrypted before storage
pub fn lpush(&self, key: &str, elements: Vec<String>) -> Result<i64, DBError> {
let write_txn = self.db.begin_write()?;
let mut _length = 0i64;
{
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
// Set the type to list
types_table.insert(key, "list")?;
// Get current list or create empty one
let mut list: Vec<String> = match lists_table.get(key)? {
Some(data) => {
let decrypted = self.decrypt_if_needed(data.value())?;
serde_json::from_slice(&decrypted)?
}
None => Vec::new(),
};
// Add elements to the front (left)
for element in elements.into_iter().rev() {
list.insert(0, element);
}
_length = list.len() as i64;
// Encrypt and store the updated list
let serialized = serde_json::to_vec(&list)?;
let encrypted = self.encrypt_if_needed(&serialized)?;
lists_table.insert(key, encrypted.as_slice())?;
}
write_txn.commit()?;
Ok(_length)
}
// ✅ ENCRYPTION APPLIED: Elements are encrypted before storage
pub fn rpush(&self, key: &str, elements: Vec<String>) -> Result<i64, DBError> {
let write_txn = self.db.begin_write()?;
let mut _length = 0i64;
{
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
// Set the type to list
types_table.insert(key, "list")?;
// Get current list or create empty one
let mut list: Vec<String> = match lists_table.get(key)? {
Some(data) => {
let decrypted = self.decrypt_if_needed(data.value())?;
serde_json::from_slice(&decrypted)?
}
None => Vec::new(),
};
// Add elements to the end (right)
list.extend(elements);
_length = list.len() as i64;
// Encrypt and store the updated list
let serialized = serde_json::to_vec(&list)?;
let encrypted = self.encrypt_if_needed(&serialized)?;
lists_table.insert(key, encrypted.as_slice())?;
}
write_txn.commit()?;
Ok(_length)
}
// ✅ ENCRYPTION APPLIED: Elements are decrypted after retrieval and encrypted before storage
pub fn lpop(&self, key: &str, count: u64) -> Result<Vec<String>, DBError> {
let write_txn = self.db.begin_write()?;
let mut result = Vec::new();
// First check if key exists and is a list, and get the data
let list_data = {
let types_table = write_txn.open_table(TYPES_TABLE)?;
let lists_table = write_txn.open_table(LISTS_TABLE)?;
let result = match types_table.get(key)? {
Some(type_val) if type_val.value() == "list" => {
if let Some(data) = lists_table.get(key)? {
let decrypted = self.decrypt_if_needed(data.value())?;
let list: Vec<String> = serde_json::from_slice(&decrypted)?;
Some(list)
} else {
None
}
}
_ => None,
};
result
};
if let Some(mut list) = list_data {
let pop_count = std::cmp::min(count as usize, list.len());
for _ in 0..pop_count {
if !list.is_empty() {
result.push(list.remove(0));
}
}
let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
if list.is_empty() {
// Remove the key if list is empty
lists_table.remove(key)?;
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
types_table.remove(key)?;
} else {
// Encrypt and store the updated list
let serialized = serde_json::to_vec(&list)?;
let encrypted = self.encrypt_if_needed(&serialized)?;
lists_table.insert(key, encrypted.as_slice())?;
}
}
write_txn.commit()?;
Ok(result)
}
// ✅ ENCRYPTION APPLIED: Elements are decrypted after retrieval and encrypted before storage
pub fn rpop(&self, key: &str, count: u64) -> Result<Vec<String>, DBError> {
let write_txn = self.db.begin_write()?;
let mut result = Vec::new();
// First check if key exists and is a list, and get the data
let list_data = {
let types_table = write_txn.open_table(TYPES_TABLE)?;
let lists_table = write_txn.open_table(LISTS_TABLE)?;
let result = match types_table.get(key)? {
Some(type_val) if type_val.value() == "list" => {
if let Some(data) = lists_table.get(key)? {
let decrypted = self.decrypt_if_needed(data.value())?;
let list: Vec<String> = serde_json::from_slice(&decrypted)?;
Some(list)
} else {
None
}
}
_ => None,
};
result
};
if let Some(mut list) = list_data {
let pop_count = std::cmp::min(count as usize, list.len());
for _ in 0..pop_count {
if !list.is_empty() {
result.push(list.pop().unwrap());
}
}
let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
if list.is_empty() {
// Remove the key if list is empty
lists_table.remove(key)?;
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
types_table.remove(key)?;
} else {
// Encrypt and store the updated list
let serialized = serde_json::to_vec(&list)?;
let encrypted = self.encrypt_if_needed(&serialized)?;
lists_table.insert(key, encrypted.as_slice())?;
}
}
write_txn.commit()?;
Ok(result)
}
pub fn llen(&self, key: &str) -> Result<i64, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "list" => {
let lists_table = read_txn.open_table(LISTS_TABLE)?;
match lists_table.get(key)? {
Some(data) => {
let decrypted = self.decrypt_if_needed(data.value())?;
let list: Vec<String> = serde_json::from_slice(&decrypted)?;
Ok(list.len() as i64)
}
None => Ok(0),
}
}
_ => Ok(0),
}
}
// ✅ ENCRYPTION APPLIED: Element is decrypted after retrieval
pub fn lindex(&self, key: &str, index: i64) -> Result<Option<String>, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "list" => {
let lists_table = read_txn.open_table(LISTS_TABLE)?;
match lists_table.get(key)? {
Some(data) => {
let decrypted = self.decrypt_if_needed(data.value())?;
let list: Vec<String> = serde_json::from_slice(&decrypted)?;
let actual_index = if index < 0 {
list.len() as i64 + index
} else {
index
};
if actual_index >= 0 && (actual_index as usize) < list.len() {
Ok(Some(list[actual_index as usize].clone()))
} else {
Ok(None)
}
}
None => Ok(None),
}
}
_ => Ok(None),
}
}
// ✅ ENCRYPTION APPLIED: Elements are decrypted after retrieval
pub fn lrange(&self, key: &str, start: i64, stop: i64) -> Result<Vec<String>, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "list" => {
let lists_table = read_txn.open_table(LISTS_TABLE)?;
match lists_table.get(key)? {
Some(data) => {
let decrypted = self.decrypt_if_needed(data.value())?;
let list: Vec<String> = serde_json::from_slice(&decrypted)?;
if list.is_empty() {
return Ok(Vec::new());
}
let len = list.len() as i64;
let start_idx = if start < 0 { std::cmp::max(0, len + start) } else { std::cmp::min(start, len) };
let stop_idx = if stop < 0 { std::cmp::max(-1, len + stop) } else { std::cmp::min(stop, len - 1) };
if start_idx > stop_idx || start_idx >= len {
return Ok(Vec::new());
}
let start_usize = start_idx as usize;
let stop_usize = (stop_idx + 1) as usize;
Ok(list[start_usize..std::cmp::min(stop_usize, list.len())].to_vec())
}
None => Ok(Vec::new()),
}
}
_ => Ok(Vec::new()),
}
}
// ✅ ENCRYPTION APPLIED: Elements are decrypted after retrieval and encrypted before storage
pub fn ltrim(&self, key: &str, start: i64, stop: i64) -> Result<(), DBError> {
let write_txn = self.db.begin_write()?;
// First check if key exists and is a list, and get the data
let list_data = {
let types_table = write_txn.open_table(TYPES_TABLE)?;
let lists_table = write_txn.open_table(LISTS_TABLE)?;
let result = match types_table.get(key)? {
Some(type_val) if type_val.value() == "list" => {
if let Some(data) = lists_table.get(key)? {
let decrypted = self.decrypt_if_needed(data.value())?;
let list: Vec<String> = serde_json::from_slice(&decrypted)?;
Some(list)
} else {
None
}
}
_ => None,
};
result
};
if let Some(list) = list_data {
if list.is_empty() {
write_txn.commit()?;
return Ok(());
}
let len = list.len() as i64;
let start_idx = if start < 0 { std::cmp::max(0, len + start) } else { std::cmp::min(start, len) };
let stop_idx = if stop < 0 { std::cmp::max(-1, len + stop) } else { std::cmp::min(stop, len - 1) };
let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
if start_idx > stop_idx || start_idx >= len {
// Remove the entire list
lists_table.remove(key)?;
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
types_table.remove(key)?;
} else {
let start_usize = start_idx as usize;
let stop_usize = (stop_idx + 1) as usize;
let trimmed = list[start_usize..std::cmp::min(stop_usize, list.len())].to_vec();
if trimmed.is_empty() {
lists_table.remove(key)?;
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
types_table.remove(key)?;
} else {
// Encrypt and store the trimmed list
let serialized = serde_json::to_vec(&trimmed)?;
let encrypted = self.encrypt_if_needed(&serialized)?;
lists_table.insert(key, encrypted.as_slice())?;
}
}
}
write_txn.commit()?;
Ok(())
}
// ✅ ENCRYPTION APPLIED: Elements are decrypted after retrieval and encrypted before storage
pub fn lrem(&self, key: &str, count: i64, element: &str) -> Result<i64, DBError> {
let write_txn = self.db.begin_write()?;
let mut removed = 0i64;
// First check if key exists and is a list, and get the data
let list_data = {
let types_table = write_txn.open_table(TYPES_TABLE)?;
let lists_table = write_txn.open_table(LISTS_TABLE)?;
let result = match types_table.get(key)? {
Some(type_val) if type_val.value() == "list" => {
if let Some(data) = lists_table.get(key)? {
let decrypted = self.decrypt_if_needed(data.value())?;
let list: Vec<String> = serde_json::from_slice(&decrypted)?;
Some(list)
} else {
None
}
}
_ => None,
};
result
};
if let Some(mut list) = list_data {
if count == 0 {
// Remove all occurrences
let original_len = list.len();
list.retain(|x| x != element);
removed = (original_len - list.len()) as i64;
} else if count > 0 {
// Remove first count occurrences
let mut to_remove = count as usize;
list.retain(|x| {
if x == element && to_remove > 0 {
to_remove -= 1;
removed += 1;
false
} else {
true
}
});
} else {
// Remove last |count| occurrences
let mut to_remove = (-count) as usize;
for i in (0..list.len()).rev() {
if list[i] == element && to_remove > 0 {
list.remove(i);
to_remove -= 1;
removed += 1;
}
}
}
let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
if list.is_empty() {
lists_table.remove(key)?;
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
types_table.remove(key)?;
} else {
// Encrypt and store the updated list
let serialized = serde_json::to_vec(&list)?;
let encrypted = self.encrypt_if_needed(&serialized)?;
lists_table.insert(key, encrypted.as_slice())?;
}
}
write_txn.commit()?;
Ok(removed)
}
}

View File

@@ -0,0 +1,9 @@
[package]
name = "supervisor"
version = "0.1.0"
edition = "2021"
[dependencies]
# The supervisor will eventually depend on the herodb crate.
# We can add this dependency now.
# herodb = { path = "../herodb" }

View File

@@ -0,0 +1,4 @@
fn main() {
println!("Hello from the supervisor crate!");
// Supervisor logic will be implemented here.
}

View File

@@ -0,0 +1,18 @@
[package]
name = "supervisorrpc"
version = "0.1.0"
edition = "2021"
[[bin]]
name = "supervisorrpc"
path = "src/main.rs"
[dependencies]
# Example dependencies for an RPC server
# axum = "0.7"
# jsonrpsee = { version = "0.22", features = ["server"] }
# openrpc-types = "0.7"
tokio = { workspace = true }
redis = { version = "0.24", features = ["tokio-comp"] }
herocrypto = { path = "../herocrypto" }

View File

@@ -0,0 +1,12 @@
// To be implemented:
// 1. Define an OpenRPC schema for supervisor functions (e.g., server status, key rotation).
// 2. Implement an HTTP/TCP server (e.g., using Axum or jsonrpsee) that serves the schema
// and handles RPC calls.
// 3. Implement support for Unix domain sockets in addition to TCP.
// 4. Use the `herocrypto` or `redis-rs` crate to interact with the main `herodb` instance.
#[tokio::main]
async fn main() {
println!("Supervisor RPC server starting... (not implemented)");
// Server setup code will go here.
}

View File

@@ -1,307 +0,0 @@
# 🔑 Redis `HSET` and Related Hash Commands
## 1. `HSET`
* **Purpose**: Set the value of one or more fields in a hash.
* **Syntax**:
```bash
HSET key field value [field value ...]
```
* **Return**:
* Integer: number of fields that were newly added.
* **RESP Protocol**:
```
*4
$4
HSET
$3
key
$5
field
$5
value
```
(If multiple field-value pairs: `*6`, `*8`, etc.)
---
## 2. `HSETNX`
* **Purpose**: Set the value of a hash field only if it does **not** exist.
* **Syntax**:
```bash
HSETNX key field value
```
* **Return**:
* `1` if field was set.
* `0` if field already exists.
* **RESP Protocol**:
```
*4
$6
HSETNX
$3
key
$5
field
$5
value
```
---
## 3. `HGET`
* **Purpose**: Get the value of a hash field.
* **Syntax**:
```bash
HGET key field
```
* **Return**:
* Bulk string (value) or `nil` if field does not exist.
* **RESP Protocol**:
```
*3
$4
HGET
$3
key
$5
field
```
---
## 4. `HGETALL`
* **Purpose**: Get all fields and values in a hash.
* **Syntax**:
```bash
HGETALL key
```
* **Return**:
* Array of `[field1, value1, field2, value2, ...]`.
* **RESP Protocol**:
```
*2
$7
HGETALL
$3
key
```
---
## 5. `HMSET` (⚠️ Deprecated, use `HSET`)
* **Purpose**: Set multiple field-value pairs.
* **Syntax**:
```bash
HMSET key field value [field value ...]
```
* **Return**:
* Always `OK`.
* **RESP Protocol**:
```
*6
$5
HMSET
$3
key
$5
field
$5
value
$5
field2
$5
value2
```
---
## 6. `HMGET`
* **Purpose**: Get values of multiple fields.
* **Syntax**:
```bash
HMGET key field [field ...]
```
* **Return**:
* Array of values (bulk strings or nils).
* **RESP Protocol**:
```
*4
$5
HMGET
$3
key
$5
field1
$5
field2
```
---
## 7. `HDEL`
* **Purpose**: Delete one or more fields from a hash.
* **Syntax**:
```bash
HDEL key field [field ...]
```
* **Return**:
* Integer: number of fields removed.
* **RESP Protocol**:
```
*3
$4
HDEL
$3
key
$5
field
```
---
## 8. `HEXISTS`
* **Purpose**: Check if a field exists.
* **Syntax**:
```bash
HEXISTS key field
```
* **Return**:
* `1` if exists, `0` if not.
* **RESP Protocol**:
```
*3
$7
HEXISTS
$3
key
$5
field
```
---
## 9. `HKEYS`
* **Purpose**: Get all field names in a hash.
* **Syntax**:
```bash
HKEYS key
```
* **Return**:
* Array of field names.
* **RESP Protocol**:
```
*2
$5
HKEYS
$3
key
```
---
## 10. `HVALS`
* **Purpose**: Get all values in a hash.
* **Syntax**:
```bash
HVALS key
```
* **Return**:
* Array of values.
* **RESP Protocol**:
```
*2
$5
HVALS
$3
key
```
---
## 11. `HLEN`
* **Purpose**: Get number of fields in a hash.
* **Syntax**:
```bash
HLEN key
```
* **Return**:
* Integer: number of fields.
* **RESP Protocol**:
```
*2
$4
HLEN
$3
key
```
## 12. `HSCAN`
* **Purpose**: Iterate fields/values of a hash (cursor-based scan).
* **Syntax**:
```bash
HSCAN key cursor [MATCH pattern] [COUNT count]
```
* **Return**:
* Array: `[new-cursor, [field1, value1, ...]]`
* **RESP Protocol**:
```
*3
$5
HSCAN
$3
key
$1
0
```

View File

@@ -1,80 +0,0 @@
========================
CODE SNIPPETS
========================
TITLE: 1PC+C Commit Strategy Vulnerability Example
DESCRIPTION: Illustrates a scenario where a partially committed transaction might appear complete due to the non-cryptographic checksum (XXH3) used in the 1PC+C commit strategy. This requires controlling page flush order, introducing a crash during fsync, and ensuring valid checksums for partially written data.
SOURCE: https://github.com/cberner/redb/blob/master/docs/design.md#_snippet_9
LANGUAGE: rust
CODE:
```
table.insert(malicious_key, malicious_value);
table.insert(good_key, good_value);
txn.commit();
```
LANGUAGE: rust
CODE:
```
table.insert(malicious_key, malicious_value);
txn.commit();
```
----------------------------------------
TITLE: Basic Key-Value Operations in redb
DESCRIPTION: Demonstrates the fundamental usage of redb for creating a database, opening a table, inserting a key-value pair, and retrieving the value within separate read and write transactions.
SOURCE: https://github.com/cberner/redb/blob/master/README.md#_snippet_0
LANGUAGE: rust
CODE:
```
use redb::{Database, Error, ReadableTable, TableDefinition};
const TABLE: TableDefinition<&str, u64> = TableDefinition::new("my_data");
fn main() -> Result<(), Error> {
let db = Database::create("my_db.redb")?;
let write_txn = db.begin_write()?;
{
let mut table = write_txn.open_table(TABLE)?;
table.insert("my_key", &123)?;
}
write_txn.commit()?;
let read_txn = db.begin_read()?;
let table = read_txn.open_table(TABLE)?;
assert_eq!(table.get("my_key")?.unwrap().value(), 123);
Ok(())
}
```
## What *redb* currently supports:
* Simple operations like creating databases, inserting key-value pairs, opening and reading tables ([GitHub][1]).
* No mention of operations such as:
* Iterating over keys with a given prefix.
* Range queries based on string prefixes.
* Specialized prefixfiltered lookups.
## implement range scans as follows
You can implement prefix-like functionality using **range scans** combined with manual checks, similar to using a `BTreeSet` in Rust:
```rust
for key in table.range(prefix..).keys() {
if !key.starts_with(prefix) {
break;
}
// process key
}
```
This pattern iterates keys starting at the prefix, and stops once a key no longer matches the prefix—this works because the keys are sorted ([GitHub][1]).

View File

@@ -1,250 +0,0 @@
Got it 👍 — lets break this down properly.
Redis has two broad classes youre asking about:
1. **Basic key-space functions** (SET, GET, DEL, EXISTS, etc.)
2. **Iteration commands** (`SCAN`, `SSCAN`, `HSCAN`, `ZSCAN`)
And for each Ill show:
* What it does
* How it works at a high level
* Its **RESP protocol implementation** (the actual wire format).
---
# 1. Basic Key-Space Commands
### `SET key value`
* Stores a string value at a key.
* Overwrites if the key already exists.
**Protocol (RESP2):**
```
*3
$3
SET
$3
foo
$3
bar
```
(client sends: array of 3 bulk strings: `["SET", "foo", "bar"]`)
**Reply:**
```
+OK
```
---
### `GET key`
* Retrieves the string value stored at the key.
* Returns `nil` if key doesnt exist.
**Protocol:**
```
*2
$3
GET
$3
foo
```
**Reply:**
```
$3
bar
```
(or `$-1` for nil)
---
### `DEL key [key ...]`
* Removes one or more keys.
* Returns number of keys actually removed.
**Protocol:**
```
*2
$3
DEL
$3
foo
```
**Reply:**
```
:1
```
(integer reply = number of deleted keys)
---
### `EXISTS key [key ...]`
* Checks if one or more keys exist.
* Returns count of existing keys.
**Protocol:**
```
*2
$6
EXISTS
$3
foo
```
**Reply:**
```
:1
```
---
### `KEYS pattern`
* Returns all keys matching a glob-style pattern.
⚠️ Not efficient in production (O(N)), better to use `SCAN`.
**Protocol:**
```
*2
$4
KEYS
$1
*
```
**Reply:**
```
*2
$3
foo
$3
bar
```
(array of bulk strings with key names)
---
# 2. Iteration Commands (`SCAN` family)
### `SCAN cursor [MATCH pattern] [COUNT n]`
* Iterates the keyspace incrementally.
* Client keeps sending back the cursor from previous call until it returns `0`.
**Protocol example:**
```
*2
$4
SCAN
$1
0
```
**Reply:**
```
*2
$1
0
*2
$3
foo
$3
bar
```
Explanation:
* First element = new cursor (`"0"` means iteration finished).
* Second element = array of keys returned in this batch.
---
### `HSCAN key cursor [MATCH pattern] [COUNT n]`
* Like `SCAN`, but iterates fields of a hash.
**Protocol:**
```
*3
$5
HSCAN
$3
myh
$1
0
```
**Reply:**
```
*2
$1
0
*4
$5
field
$5
value
$5
age
$2
42
```
(Array of alternating field/value pairs)
---
### `SSCAN key cursor [MATCH pattern] [COUNT n]`
* Iterates members of a set.
Protocol and reply structure same as SCAN.
---
### `ZSCAN key cursor [MATCH pattern] [COUNT n]`
* Iterates members of a sorted set with scores.
* Returns alternating `member`, `score`.
---
# Quick Comparison
| Command | Purpose | Return Type |
| -------- | ----------------------------- | --------------------- |
| `SET` | Store a string value | Simple string `+OK` |
| `GET` | Retrieve a string value | Bulk string / nil |
| `DEL` | Delete keys | Integer (count) |
| `EXISTS` | Check existence | Integer (count) |
| `KEYS` | List all matching keys (slow) | Array of bulk strings |
| `SCAN` | Iterate over keys (safe) | `[cursor, array]` |
| `HSCAN` | Iterate over hash fields | `[cursor, array]` |
| `SSCAN` | Iterate over set members | `[cursor, array]` |
| `ZSCAN` | Iterate over sorted set | `[cursor, array]` |

25
run_tests.sh Executable file
View File

@@ -0,0 +1,25 @@
#!/bin/bash
echo "🧪 Running HeroDB Redis Compatibility Tests"
echo "=========================================="
echo ""
echo "1⃣ Running Simple Redis Tests (4 tests)..."
echo "----------------------------------------------"
cargo test -p herodb --test simple_redis_test -- --nocapture
echo ""
echo "2⃣ Running Comprehensive Redis Integration Tests (13 tests)..."
echo "----------------------------------------------------------------"
cargo test -p herodb --test redis_integration_tests -- --nocapture
cargo test -p herodb --test redis_basic_client -- --nocapture
cargo test -p herodb --test debug_hset -- --nocapture
cargo test -p herodb --test debug_hset_simple -- --nocapture
echo ""
echo "3⃣ Running All Workspace Tests..."
echo "--------------------------------"
cargo test --workspace -- --nocapture
echo ""
echo "✅ Test execution completed!"

View File

@@ -1,499 +0,0 @@
use crate::{error::DBError, protocol::Protocol, server::Server};
#[derive(Debug, Clone)]
pub enum Cmd {
Ping,
Echo(String),
Get(String),
Set(String, String),
SetPx(String, String, u128),
SetEx(String, String, u128),
Keys,
ConfigGet(String),
Info(Option<String>),
Del(String),
Type(String),
Incr(String),
Multi,
Exec,
Discard,
// Hash commands
HSet(String, Vec<(String, String)>),
HGet(String, String),
HGetAll(String),
HDel(String, Vec<String>),
HExists(String, String),
HKeys(String),
HVals(String),
HLen(String),
HMGet(String, Vec<String>),
HSetNx(String, String, String),
Scan(u64, Option<String>, Option<u64>), // cursor, pattern, count
Unknow,
}
impl Cmd {
pub fn from(s: &str) -> Result<(Self, Protocol), DBError> {
let protocol = Protocol::from(s)?;
match protocol.clone().0 {
Protocol::Array(p) => {
let cmd = p.into_iter().map(|x| x.decode()).collect::<Vec<_>>();
if cmd.is_empty() {
return Err(DBError("cmd length is 0".to_string()));
}
Ok((
match cmd[0].to_lowercase().as_str() {
"echo" => Cmd::Echo(cmd[1].clone()),
"ping" => Cmd::Ping,
"get" => Cmd::Get(cmd[1].clone()),
"set" => {
if cmd.len() == 5 && cmd[3].to_lowercase() == "px" {
Cmd::SetPx(cmd[1].clone(), cmd[2].clone(), cmd[4].parse().unwrap())
} else if cmd.len() == 5 && cmd[3].to_lowercase() == "ex" {
Cmd::SetEx(cmd[1].clone(), cmd[2].clone(), cmd[4].parse().unwrap())
} else if cmd.len() == 3 {
Cmd::Set(cmd[1].clone(), cmd[2].clone())
} else {
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
}
}
"config" => {
if cmd.len() != 3 || cmd[1].to_lowercase() != "get" {
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
} else {
Cmd::ConfigGet(cmd[2].clone())
}
}
"keys" => {
if cmd.len() != 2 || cmd[1] != "*" {
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
} else {
Cmd::Keys
}
}
"info" => {
let section = if cmd.len() == 2 {
Some(cmd[1].clone())
} else {
None
};
Cmd::Info(section)
}
"del" => {
if cmd.len() != 2 {
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
}
Cmd::Del(cmd[1].clone())
}
"type" => {
if cmd.len() != 2 {
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
}
Cmd::Type(cmd[1].clone())
}
"incr" => {
if cmd.len() != 2 {
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
}
Cmd::Incr(cmd[1].clone())
}
"multi" => {
if cmd.len() != 1 {
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
}
Cmd::Multi
}
"exec" => {
if cmd.len() != 1 {
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
}
Cmd::Exec
}
"discard" => Cmd::Discard,
// Hash commands
"hset" => {
if cmd.len() < 4 || (cmd.len() - 2) % 2 != 0 {
return Err(DBError(format!("wrong number of arguments for HSET command")));
}
let mut pairs = Vec::new();
let mut i = 2;
while i < cmd.len() - 1 {
pairs.push((cmd[i].clone(), cmd[i + 1].clone()));
i += 2;
}
Cmd::HSet(cmd[1].clone(), pairs)
}
"hget" => {
if cmd.len() != 3 {
return Err(DBError(format!("wrong number of arguments for HGET command")));
}
Cmd::HGet(cmd[1].clone(), cmd[2].clone())
}
"hgetall" => {
if cmd.len() != 2 {
return Err(DBError(format!("wrong number of arguments for HGETALL command")));
}
Cmd::HGetAll(cmd[1].clone())
}
"hdel" => {
if cmd.len() < 3 {
return Err(DBError(format!("wrong number of arguments for HDEL command")));
}
Cmd::HDel(cmd[1].clone(), cmd[2..].to_vec())
}
"hexists" => {
if cmd.len() != 3 {
return Err(DBError(format!("wrong number of arguments for HEXISTS command")));
}
Cmd::HExists(cmd[1].clone(), cmd[2].clone())
}
"hkeys" => {
if cmd.len() != 2 {
return Err(DBError(format!("wrong number of arguments for HKEYS command")));
}
Cmd::HKeys(cmd[1].clone())
}
"hvals" => {
if cmd.len() != 2 {
return Err(DBError(format!("wrong number of arguments for HVALS command")));
}
Cmd::HVals(cmd[1].clone())
}
"hlen" => {
if cmd.len() != 2 {
return Err(DBError(format!("wrong number of arguments for HLEN command")));
}
Cmd::HLen(cmd[1].clone())
}
"hmget" => {
if cmd.len() < 3 {
return Err(DBError(format!("wrong number of arguments for HMGET command")));
}
Cmd::HMGet(cmd[1].clone(), cmd[2..].to_vec())
}
"hsetnx" => {
if cmd.len() != 4 {
return Err(DBError(format!("wrong number of arguments for HSETNX command")));
}
Cmd::HSetNx(cmd[1].clone(), cmd[2].clone(), cmd[3].clone())
}
"scan" => {
if cmd.len() < 2 {
return Err(DBError(format!("wrong number of arguments for SCAN command")));
}
let cursor = cmd[1].parse::<u64>().map_err(|_|
DBError("ERR invalid cursor".to_string()))?;
let mut pattern = None;
let mut count = None;
let mut i = 2;
while i < cmd.len() {
match cmd[i].to_lowercase().as_str() {
"match" => {
if i + 1 >= cmd.len() {
return Err(DBError("ERR syntax error".to_string()));
}
pattern = Some(cmd[i + 1].clone());
i += 2;
}
"count" => {
if i + 1 >= cmd.len() {
return Err(DBError("ERR syntax error".to_string()));
}
count = Some(cmd[i + 1].parse::<u64>().map_err(|_|
DBError("ERR value is not an integer or out of range".to_string()))?);
i += 2;
}
_ => {
return Err(DBError(format!("ERR syntax error")));
}
}
}
Cmd::Scan(cursor, pattern, count)
}
_ => Cmd::Unknow,
},
protocol.0,
))
}
_ => Err(DBError(format!(
"fail to parse as cmd for {:?}",
protocol.0
))),
}
}
pub async fn run(
&self,
server: &Server,
protocol: Protocol,
queued_cmd: &mut Option<Vec<(Cmd, Protocol)>>,
) -> Result<Protocol, DBError> {
// Handle queued commands for transactions
if queued_cmd.is_some()
&& !matches!(self, Cmd::Exec)
&& !matches!(self, Cmd::Multi)
&& !matches!(self, Cmd::Discard)
{
queued_cmd
.as_mut()
.unwrap()
.push((self.clone(), protocol.clone()));
return Ok(Protocol::SimpleString("QUEUED".to_string()));
}
match self {
Cmd::Ping => Ok(Protocol::SimpleString("PONG".to_string())),
Cmd::Echo(s) => Ok(Protocol::SimpleString(s.clone())),
Cmd::Get(k) => get_cmd(server, k).await,
Cmd::Set(k, v) => set_cmd(server, k, v).await,
Cmd::SetPx(k, v, x) => set_px_cmd(server, k, v, x).await,
Cmd::SetEx(k, v, x) => set_ex_cmd(server, k, v, x).await,
Cmd::Del(k) => del_cmd(server, k).await,
Cmd::ConfigGet(name) => config_get_cmd(name, server),
Cmd::Keys => keys_cmd(server).await,
Cmd::Info(section) => info_cmd(section),
Cmd::Type(k) => type_cmd(server, k).await,
Cmd::Incr(key) => incr_cmd(server, key).await,
Cmd::Multi => {
*queued_cmd = Some(Vec::<(Cmd, Protocol)>::new());
Ok(Protocol::SimpleString("OK".to_string()))
}
Cmd::Exec => exec_cmd(queued_cmd, server).await,
Cmd::Discard => {
if queued_cmd.is_some() {
*queued_cmd = None;
Ok(Protocol::SimpleString("OK".to_string()))
} else {
Ok(Protocol::err("ERR DISCARD without MULTI"))
}
}
// Hash commands
Cmd::HSet(key, pairs) => hset_cmd(server, key, pairs).await,
Cmd::HGet(key, field) => hget_cmd(server, key, field).await,
Cmd::HGetAll(key) => hgetall_cmd(server, key).await,
Cmd::HDel(key, fields) => hdel_cmd(server, key, fields).await,
Cmd::HExists(key, field) => hexists_cmd(server, key, field).await,
Cmd::HKeys(key) => hkeys_cmd(server, key).await,
Cmd::HVals(key) => hvals_cmd(server, key).await,
Cmd::HLen(key) => hlen_cmd(server, key).await,
Cmd::HMGet(key, fields) => hmget_cmd(server, key, fields).await,
Cmd::HSetNx(key, field, value) => hsetnx_cmd(server, key, field, value).await,
Cmd::Scan(cursor, pattern, count) => scan_cmd(server, cursor, pattern.as_deref(), count).await,
Cmd::Unknow => Ok(Protocol::err("unknown cmd")),
}
}
}
async fn exec_cmd(
queued_cmd: &mut Option<Vec<(Cmd, Protocol)>>,
server: &Server,
) -> Result<Protocol, DBError> {
if queued_cmd.is_some() {
let mut vec = Vec::new();
for (cmd, protocol) in queued_cmd.as_ref().unwrap() {
let res = Box::pin(cmd.run(server, protocol.clone(), &mut None)).await?;
vec.push(res);
}
*queued_cmd = None;
Ok(Protocol::Array(vec))
} else {
Ok(Protocol::err("ERR EXEC without MULTI"))
}
}
async fn incr_cmd(server: &Server, key: &String) -> Result<Protocol, DBError> {
let current_value = server.storage.get(key)?;
let new_value = match current_value {
Some(v) => {
match v.parse::<i64>() {
Ok(num) => num + 1,
Err(_) => return Ok(Protocol::err("ERR value is not an integer or out of range")),
}
}
None => 1,
};
server.storage.set(key.clone(), new_value.to_string())?;
Ok(Protocol::SimpleString(new_value.to_string()))
}
fn config_get_cmd(name: &String, server: &Server) -> Result<Protocol, DBError> {
match name.as_str() {
"dir" => Ok(Protocol::Array(vec![
Protocol::BulkString(name.clone()),
Protocol::BulkString(server.option.dir.clone()),
])),
"dbfilename" => Ok(Protocol::Array(vec![
Protocol::BulkString(name.clone()),
Protocol::BulkString("herodb.redb".to_string()),
])),
_ => Err(DBError(format!("unsupported config {:?}", name))),
}
}
async fn keys_cmd(server: &Server) -> Result<Protocol, DBError> {
let keys = server.storage.keys("*")?;
Ok(Protocol::Array(
keys.into_iter().map(Protocol::BulkString).collect(),
))
}
fn info_cmd(section: &Option<String>) -> Result<Protocol, DBError> {
match section {
Some(s) => match s.as_str() {
"replication" => Ok(Protocol::BulkString(
"role:master\nmaster_replid:8371b4fb1155b71f4a04d3e1bc3e18c4a990aeea\nmaster_repl_offset:0\n".to_string()
)),
_ => Err(DBError(format!("unsupported section {:?}", s))),
},
None => Ok(Protocol::BulkString("# Server\nredis_version:7.0.0\n".to_string())),
}
}
async fn type_cmd(server: &Server, k: &String) -> Result<Protocol, DBError> {
match server.storage.get_key_type(k)? {
Some(type_str) => Ok(Protocol::SimpleString(type_str)),
None => Ok(Protocol::SimpleString("none".to_string())),
}
}
async fn del_cmd(server: &Server, k: &str) -> Result<Protocol, DBError> {
server.storage.del(k.to_string())?;
Ok(Protocol::SimpleString("1".to_string()))
}
async fn set_ex_cmd(
server: &Server,
k: &str,
v: &str,
x: &u128,
) -> Result<Protocol, DBError> {
server.storage.setx(k.to_string(), v.to_string(), *x * 1000)?;
Ok(Protocol::SimpleString("OK".to_string()))
}
async fn set_px_cmd(
server: &Server,
k: &str,
v: &str,
x: &u128,
) -> Result<Protocol, DBError> {
server.storage.setx(k.to_string(), v.to_string(), *x)?;
Ok(Protocol::SimpleString("OK".to_string()))
}
async fn set_cmd(server: &Server, k: &str, v: &str) -> Result<Protocol, DBError> {
server.storage.set(k.to_string(), v.to_string())?;
Ok(Protocol::SimpleString("OK".to_string()))
}
async fn get_cmd(server: &Server, k: &str) -> Result<Protocol, DBError> {
let v = server.storage.get(k)?;
Ok(v.map_or(Protocol::Null, Protocol::BulkString))
}
// Hash command implementations
async fn hset_cmd(server: &Server, key: &str, pairs: &[(String, String)]) -> Result<Protocol, DBError> {
let new_fields = server.storage.hset(key, pairs)?;
Ok(Protocol::SimpleString(new_fields.to_string()))
}
async fn hget_cmd(server: &Server, key: &str, field: &str) -> Result<Protocol, DBError> {
match server.storage.hget(key, field) {
Ok(Some(value)) => Ok(Protocol::BulkString(value)),
Ok(None) => Ok(Protocol::Null),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn hgetall_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
match server.storage.hgetall(key) {
Ok(pairs) => {
let mut result = Vec::new();
for (field, value) in pairs {
result.push(Protocol::BulkString(field));
result.push(Protocol::BulkString(value));
}
Ok(Protocol::Array(result))
}
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn hdel_cmd(server: &Server, key: &str, fields: &[String]) -> Result<Protocol, DBError> {
match server.storage.hdel(key, fields) {
Ok(deleted) => Ok(Protocol::SimpleString(deleted.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn hexists_cmd(server: &Server, key: &str, field: &str) -> Result<Protocol, DBError> {
match server.storage.hexists(key, field) {
Ok(exists) => Ok(Protocol::SimpleString(if exists { "1" } else { "0" }.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn hkeys_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
match server.storage.hkeys(key) {
Ok(keys) => Ok(Protocol::Array(
keys.into_iter().map(Protocol::BulkString).collect(),
)),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn hvals_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
match server.storage.hvals(key) {
Ok(values) => Ok(Protocol::Array(
values.into_iter().map(Protocol::BulkString).collect(),
)),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn hlen_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
match server.storage.hlen(key) {
Ok(len) => Ok(Protocol::SimpleString(len.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn hmget_cmd(server: &Server, key: &str, fields: &[String]) -> Result<Protocol, DBError> {
match server.storage.hmget(key, fields) {
Ok(values) => {
let result: Vec<Protocol> = values
.into_iter()
.map(|v| v.map_or(Protocol::Null, Protocol::BulkString))
.collect();
Ok(Protocol::Array(result))
}
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn hsetnx_cmd(server: &Server, key: &str, field: &str, value: &str) -> Result<Protocol, DBError> {
match server.storage.hsetnx(key, field, value) {
Ok(was_set) => Ok(Protocol::SimpleString(if was_set { "1" } else { "0" }.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn scan_cmd(server: &Server, cursor: &u64, pattern: Option<&str>, count: &Option<u64>) -> Result<Protocol, DBError> {
match server.storage.scan(*cursor, pattern, *count) {
Ok((next_cursor, keys)) => {
let mut result = Vec::new();
result.push(Protocol::BulkString(next_cursor.to_string()));
result.push(Protocol::Array(
keys.into_iter().map(Protocol::BulkString).collect(),
));
Ok(Protocol::Array(result))
}
Err(e) => Ok(Protocol::err(&e.0)),
}
}

View File

@@ -1,6 +0,0 @@
mod cmd;
pub mod error;
pub mod options;
mod protocol;
pub mod server;
mod storage;

View File

@@ -1,5 +0,0 @@
#[derive(Clone)]
pub struct DBOption {
pub dir: String,
pub port: u16,
}

View File

@@ -1,68 +0,0 @@
use core::str;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
use crate::cmd::Cmd;
use crate::error::DBError;
use crate::options;
use crate::protocol::Protocol;
use crate::storage::Storage;
#[derive(Clone)]
pub struct Server {
pub storage: Arc<Storage>,
pub option: options::DBOption,
}
impl Server {
pub async fn new(option: options::DBOption) -> Self {
// Create database file path with fixed filename
let db_file_path = PathBuf::from(option.dir.clone()).join("herodb.redb");
println!("will open db file path: {}", db_file_path.display());
// Initialize storage with redb
let storage = Storage::new(db_file_path).expect("Failed to initialize storage");
Server {
storage: Arc::new(storage),
option,
}
}
pub async fn handle(
&mut self,
mut stream: tokio::net::TcpStream,
) -> Result<(), DBError> {
let mut buf = [0; 512];
let mut queued_cmd: Option<Vec<(Cmd, Protocol)>> = None;
loop {
if let Ok(len) = stream.read(&mut buf).await {
if len == 0 {
println!("[handle] connection closed");
return Ok(());
}
let s = str::from_utf8(&buf[..len])?;
let (cmd, protocol) =
Cmd::from(s).unwrap_or((Cmd::Unknow, Protocol::err("unknow cmd")));
println!("got command: {:?}, protocol: {:?}", cmd, protocol);
let res = cmd
.run(self, protocol, &mut queued_cmd)
.await
.unwrap_or(Protocol::err("unknow cmd"));
print!("queued cmd {:?}", queued_cmd);
println!("going to send response {}", res.encode());
_ = stream.write(res.encode().as_bytes()).await?;
} else {
println!("[handle] going to break");
break;
}
}
Ok(())
}
}

View File

@@ -1,509 +0,0 @@
use std::{
path::Path,
time::{SystemTime, UNIX_EPOCH},
};
use redb::{Database, Error, ReadableTable, Table, TableDefinition, WriteTransaction, ReadTransaction};
use serde::{Deserialize, Serialize};
use crate::error::DBError;
// Table definitions for different Redis data types
const TYPES_TABLE: TableDefinition<&str, &str> = TableDefinition::new("types");
const STRINGS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("strings");
const HASHES_TABLE: TableDefinition<(&str, &str), &str> = TableDefinition::new("hashes");
const STREAMS_META_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("streams_meta");
const STREAMS_DATA_TABLE: TableDefinition<(&str, &str), &[u8]> = TableDefinition::new("streams_data");
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct StringValue {
pub value: String,
pub expires_at_ms: Option<u128>,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct StreamEntry {
pub fields: Vec<(String, String)>,
}
#[inline]
pub fn now_in_millis() -> u128 {
let start = SystemTime::now();
let duration_since_epoch = start.duration_since(UNIX_EPOCH).unwrap();
duration_since_epoch.as_millis()
}
pub struct Storage {
db: Database,
}
impl Storage {
pub fn new(path: impl AsRef<Path>) -> Result<Self, DBError> {
let db = Database::create(path)?;
// Create tables if they don't exist
let write_txn = db.begin_write()?;
{
let _ = write_txn.open_table(TYPES_TABLE)?;
let _ = write_txn.open_table(STRINGS_TABLE)?;
let _ = write_txn.open_table(HASHES_TABLE)?;
let _ = write_txn.open_table(STREAMS_META_TABLE)?;
let _ = write_txn.open_table(STREAMS_DATA_TABLE)?;
}
write_txn.commit()?;
Ok(Storage { db })
}
pub fn get_key_type(&self, key: &str) -> Result<Option<String>, DBError> {
let read_txn = self.db.begin_read()?;
let table = read_txn.open_table(TYPES_TABLE)?;
match table.get(key)? {
Some(type_val) => Ok(Some(type_val.value().to_string())),
None => Ok(None),
}
}
pub fn get(&self, key: &str) -> Result<Option<String>, DBError> {
let read_txn = self.db.begin_read()?;
// Check if key exists and is of string type
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "string" => {
let strings_table = read_txn.open_table(STRINGS_TABLE)?;
match strings_table.get(key)? {
Some(data) => {
let string_value: StringValue = bincode::deserialize(data.value())?;
// Check if expired
if let Some(expires_at) = string_value.expires_at_ms {
if now_in_millis() > expires_at {
// Key expired, remove it
drop(read_txn);
self.del(key.to_string())?;
return Ok(None);
}
}
Ok(Some(string_value.value))
}
None => Ok(None),
}
}
_ => Ok(None),
}
}
pub fn set(&self, key: String, value: String) -> Result<(), DBError> {
let write_txn = self.db.begin_write()?;
{
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
types_table.insert(key.as_str(), "string")?;
let mut strings_table = write_txn.open_table(STRINGS_TABLE)?;
let string_value = StringValue {
value,
expires_at_ms: None,
};
let serialized = bincode::serialize(&string_value)?;
strings_table.insert(key.as_str(), serialized.as_slice())?;
}
write_txn.commit()?;
Ok(())
}
pub fn setx(&self, key: String, value: String, expire_ms: u128) -> Result<(), DBError> {
let write_txn = self.db.begin_write()?;
{
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
types_table.insert(key.as_str(), "string")?;
let mut strings_table = write_txn.open_table(STRINGS_TABLE)?;
let string_value = StringValue {
value,
expires_at_ms: Some(expire_ms + now_in_millis()),
};
let serialized = bincode::serialize(&string_value)?;
strings_table.insert(key.as_str(), serialized.as_slice())?;
}
write_txn.commit()?;
Ok(())
}
pub fn del(&self, key: String) -> Result<(), DBError> {
let write_txn = self.db.begin_write()?;
{
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
let mut strings_table = write_txn.open_table(STRINGS_TABLE)?;
let mut hashes_table = write_txn.open_table(HASHES_TABLE)?;
// Remove from type table
types_table.remove(key.as_str())?;
// Remove from strings table
strings_table.remove(key.as_str())?;
// Remove all hash fields for this key
let mut to_remove = Vec::new();
let mut iter = hashes_table.iter()?;
while let Some(entry) = iter.next() {
let entry = entry?;
let (hash_key, field) = entry.0.value();
if hash_key == key.as_str() {
to_remove.push((hash_key.to_string(), field.to_string()));
}
}
drop(iter);
for (hash_key, field) in to_remove {
hashes_table.remove((hash_key.as_str(), field.as_str()))?;
}
}
write_txn.commit()?;
Ok(())
}
pub fn keys(&self, pattern: &str) -> Result<Vec<String>, DBError> {
let read_txn = self.db.begin_read()?;
let table = read_txn.open_table(TYPES_TABLE)?;
let mut keys = Vec::new();
let mut iter = table.iter()?;
while let Some(entry) = iter.next() {
let key = entry?.0.value().to_string();
if pattern == "*" || key.contains(pattern) {
keys.push(key);
}
}
Ok(keys)
}
// Hash operations
pub fn hset(&self, key: &str, pairs: &[(String, String)]) -> Result<u64, DBError> {
let write_txn = self.db.begin_write()?;
let mut new_fields = 0u64;
{
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
let mut hashes_table = write_txn.open_table(HASHES_TABLE)?;
// Check if key exists and is of correct type
let existing_type = match types_table.get(key)? {
Some(type_val) => Some(type_val.value().to_string()),
None => None,
};
match existing_type {
Some(ref type_str) if type_str != "hash" => {
return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string()));
}
None => {
// Set type to hash
types_table.insert(key, "hash")?;
}
_ => {}
}
for (field, value) in pairs {
let existed = hashes_table.get((key, field.as_str()))?.is_some();
hashes_table.insert((key, field.as_str()), value.as_str())?;
if !existed {
new_fields += 1;
}
}
}
write_txn.commit()?;
Ok(new_fields)
}
pub fn hget(&self, key: &str, field: &str) -> Result<Option<String>, DBError> {
let read_txn = self.db.begin_read()?;
// Check type
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "hash" => {
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
match hashes_table.get((key, field))? {
Some(value) => Ok(Some(value.value().to_string())),
None => Ok(None),
}
}
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
None => Ok(None),
}
}
pub fn hgetall(&self, key: &str) -> Result<Vec<(String, String)>, DBError> {
let read_txn = self.db.begin_read()?;
// Check type
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "hash" => {
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
let mut result = Vec::new();
let mut iter = hashes_table.iter()?;
while let Some(entry) = iter.next() {
let entry = entry?;
let (hash_key, field) = entry.0.value();
let value = entry.1.value();
if hash_key == key {
result.push((field.to_string(), value.to_string()));
}
}
Ok(result)
}
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
None => Ok(Vec::new()),
}
}
pub fn hdel(&self, key: &str, fields: &[String]) -> Result<u64, DBError> {
let write_txn = self.db.begin_write()?;
let mut deleted = 0u64;
{
let types_table = write_txn.open_table(TYPES_TABLE)?;
let key_type = types_table.get(key)?;
match key_type {
Some(type_val) if type_val.value() == "hash" => {
let mut hashes_table = write_txn.open_table(HASHES_TABLE)?;
for field in fields {
if hashes_table.remove((key, field.as_str()))?.is_some() {
deleted += 1;
}
}
}
Some(_) => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
None => {}
}
}
write_txn.commit()?;
Ok(deleted)
}
pub fn hexists(&self, key: &str, field: &str) -> Result<bool, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "hash" => {
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
Ok(hashes_table.get((key, field))?.is_some())
}
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
None => Ok(false),
}
}
pub fn hkeys(&self, key: &str) -> Result<Vec<String>, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "hash" => {
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
let mut result = Vec::new();
let mut iter = hashes_table.iter()?;
while let Some(entry) = iter.next() {
let entry = entry?;
let (hash_key, field) = entry.0.value();
if hash_key == key {
result.push(field.to_string());
}
}
Ok(result)
}
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
None => Ok(Vec::new()),
}
}
pub fn hvals(&self, key: &str) -> Result<Vec<String>, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "hash" => {
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
let mut result = Vec::new();
let mut iter = hashes_table.iter()?;
while let Some(entry) = iter.next() {
let entry = entry?;
let (hash_key, _) = entry.0.value();
let value = entry.1.value();
if hash_key == key {
result.push(value.to_string());
}
}
Ok(result)
}
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
None => Ok(Vec::new()),
}
}
pub fn hlen(&self, key: &str) -> Result<u64, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "hash" => {
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
let mut count = 0u64;
let mut iter = hashes_table.iter()?;
while let Some(entry) = iter.next() {
let entry = entry?;
let (hash_key, _) = entry.0.value();
if hash_key == key {
count += 1;
}
}
Ok(count)
}
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
None => Ok(0),
}
}
pub fn hmget(&self, key: &str, fields: &[String]) -> Result<Vec<Option<String>>, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "hash" => {
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
let mut result = Vec::new();
for field in fields {
match hashes_table.get((key, field.as_str()))? {
Some(value) => result.push(Some(value.value().to_string())),
None => result.push(None),
}
}
Ok(result)
}
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
None => Ok(fields.iter().map(|_| None).collect()),
}
}
pub fn hsetnx(&self, key: &str, field: &str, value: &str) -> Result<bool, DBError> {
let write_txn = self.db.begin_write()?;
let mut result = false;
{
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
let mut hashes_table = write_txn.open_table(HASHES_TABLE)?;
// Check if key exists and is of correct type
let existing_type = match types_table.get(key)? {
Some(type_val) => Some(type_val.value().to_string()),
None => None,
};
match existing_type {
Some(ref type_str) if type_str != "hash" => {
return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string()));
}
None => {
// Set type to hash
types_table.insert(key, "hash")?;
}
_ => {}
}
// Check if field already exists
if hashes_table.get((key, field))?.is_none() {
hashes_table.insert((key, field), value)?;
result = true;
}
}
write_txn.commit()?;
Ok(result)
}
pub fn scan(&self, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<String>), DBError> {
let read_txn = self.db.begin_read()?;
let table = read_txn.open_table(TYPES_TABLE)?;
let count = count.unwrap_or(10); // Default count is 10
let mut keys = Vec::new();
let mut current_cursor = 0u64;
let mut returned_keys = 0u64;
let mut iter = table.iter()?;
while let Some(entry) = iter.next() {
let key = entry?.0.value().to_string();
// Skip keys until we reach the cursor position
if current_cursor < cursor {
current_cursor += 1;
continue;
}
// Check if key matches pattern
let matches = match pattern {
Some(pat) => {
if pat == "*" {
true
} else if pat.contains('*') {
// Simple glob pattern matching
let pattern_parts: Vec<&str> = pat.split('*').collect();
if pattern_parts.len() == 2 {
let prefix = pattern_parts[0];
let suffix = pattern_parts[1];
key.starts_with(prefix) && key.ends_with(suffix)
} else {
key.contains(&pat.replace('*', ""))
}
} else {
key.contains(pat)
}
}
None => true,
};
if matches {
keys.push(key);
returned_keys += 1;
// Stop if we've returned enough keys
if returned_keys >= count {
current_cursor += 1;
break;
}
}
current_cursor += 1;
}
// If we've reached the end of iteration, return cursor 0 to indicate completion
let next_cursor = if returned_keys < count { 0 } else { current_cursor };
Ok((next_cursor, keys))
}
}

View File

@@ -14,7 +14,7 @@ NC='\033[0m' # No Color
# Configuration # Configuration
DB_DIR="./test_db" DB_DIR="./test_db"
PORT=6379 PORT=6381
SERVER_PID="" SERVER_PID=""
# Function to print colored output # Function to print colored output
@@ -288,7 +288,7 @@ main() {
# Build the project # Build the project
print_status "Building HeroDB..." print_status "Building HeroDB..."
if ! cargo build --release; then if ! cargo build -p herodb --release; then
print_error "Failed to build HeroDB" print_error "Failed to build HeroDB"
exit 1 exit 1
fi fi
@@ -298,7 +298,7 @@ main() {
# Start the server # Start the server
print_status "Starting HeroDB server..." print_status "Starting HeroDB server..."
./target/release/redis-rs --dir "$DB_DIR" --port $PORT & ./target/release/herodb --dir "$DB_DIR" --port $PORT &
SERVER_PID=$! SERVER_PID=$!
# Wait for server to start # Wait for server to start