24 Commits

Author SHA1 Message Date
ee163bb6bf ... 2025-08-16 15:16:15 +02:00
84611dd245 ... 2025-08-16 15:10:55 +02:00
200d0c928d ... 2025-08-16 14:22:56 +02:00
30a09e6d53 ... 2025-08-16 13:58:40 +02:00
542996a0ff ... 2025-08-16 13:33:56 +02:00
63ab39b4b1 ... 2025-08-16 11:22:01 +02:00
ee94d731d7 ... 2025-08-16 11:09:18 +02:00
c7945624bd ... 2025-08-16 10:53:48 +02:00
f8dd304820 it works 2025-08-16 10:41:26 +02:00
5eab3b080c ... 2025-08-16 10:28:28 +02:00
246304b9fa ... 2025-08-16 10:10:24 +02:00
074be114c3 ... 2025-08-16 09:55:34 +02:00
e51af83e45 ... 2025-08-16 09:52:36 +02:00
dbd0635cd9 ... 2025-08-16 09:50:56 +02:00
0000d82799 ... 2025-08-16 09:29:18 +02:00
5502ff4bc5 ... 2025-08-16 09:06:33 +02:00
0511dddd99 ... 2025-08-16 08:50:28 +02:00
bec9b20ec7 ... 2025-08-16 08:41:19 +02:00
ad255a9f51 ... 2025-08-16 08:28:52 +02:00
7bcb673361 ... 2025-08-16 08:25:25 +02:00
0f6e595000 ... 2025-08-16 07:54:55 +02:00
d3e28cafe4 ... 2025-08-16 07:23:20 +02:00
de2be4a785 ... 2025-08-16 07:18:55 +02:00
cd61406d1d ... 2025-08-16 06:58:04 +02:00
45 changed files with 7795 additions and 2 deletions

1
.gitattributes vendored Normal file
View File

@@ -0,0 +1 @@
* text=auto

12
.gitignore vendored Normal file
View File

@@ -0,0 +1,12 @@
# Generated by Cargo
# will have compiled files and executables
debug/
target/
.vscode/
# These are backup files generated by rustfmt
**/*.rs.bk
# MSVC Windows builds of rustc generate these, which store debugging information
*.pdb
dumb.rdb

2104
Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

38
Cargo.toml Normal file
View File

@@ -0,0 +1,38 @@
[workspace]
resolver = "2"
members = [
"crates/herodb",
"crates/libdbstorage",
"crates/libcrypto",
"crates/libcryptoa",
"crates/herocrypto",
"crates/supervisor",
"crates/supervisorrpc",
]
[workspace.dependencies]
# Common
anyhow = "1.0"
tokio = { version = "1", features = ["full"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
thiserror = "1.0"
log = "0.4"
bytes = "1.3"
# Crypto - Asymmetric
age = "0.10"
secrecy = "0.8"
ed25519-dalek = "2"
base64 = "0.22"
# Crypto - Symmetric & Utilities
chacha20poly1305 = "0.10"
rand = "0.8"
sha2 = "0.10"
# Database
redb = "2.1"
# CLI
clap = { version = "4.5", features = ["derive"] }

21
LICENSE Normal file
View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2024 Pin Fang
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -1,2 +0,0 @@
# herodb

View File

@@ -0,0 +1,10 @@
[package]
name = "herocrypto"
version = "0.1.0"
edition = "2021"
[dependencies]
redis = { version = "0.24", features = ["tokio-comp"] }
thiserror = { workspace = true }
libcrypto = { path = "../libcrypto" }
libcryptoa = { path = "../libcryptoa" }

View File

@@ -0,0 +1,45 @@
// In crates/herocrypto/src/lib.rs
use redis::{Commands, RedisResult};
use thiserror::Error;
#[derive(Error, Debug)]
pub enum Error {
#[error("Redis connection error: {0}")]
Redis(#[from] redis::RedisError),
#[error("Asymmetric crypto error: {0}")]
Asymmetric(#[from] libcryptoa::AsymmetricCryptoError),
#[error("Key not found in database: {0}")]
KeyNotFound(String),
#[error("Command failed on server: {0}")]
CommandError(String),
}
pub struct HeroCrypto {
// e.g., using a connection manager from redis-rs
client: redis::Client,
}
impl HeroCrypto {
pub fn new(redis_url: &str) -> Result<Self, Error> {
Ok(Self { client: redis::Client::open(redis_url)? })
}
// --- High-level functions to be implemented ---
/// Generates a new keypair and stores it in HeroDB under the given name.
pub async fn generate_keypair(&self, name: &str) -> Result<(), Error> {
let mut con = self.client.get_async_connection().await?;
let (_pub, _priv): (String, String) = redis::cmd("AGE")
.arg("KEYGEN")
.arg(name)
.query_async(&mut con)
.await?;
Ok(())
}
/// Encrypts a message using a key stored in HeroDB.
pub async fn encrypt_by_name(&self, key_name: &str, plaintext: &str) -> Result<String, Error> {
// Implementation will call 'AGE ENCRYPTNAME ...'
unimplemented!()
}
}

31
crates/herodb/Cargo.toml Normal file
View File

@@ -0,0 +1,31 @@
[package]
name = "herodb"
version = "0.1.0"
edition = "2021"
authors = ["Pin Fang <fpfangpin@hotmail.com>"]
[[bin]]
name = "herodb"
path = "src/main.rs"
[dependencies]
# Workspace dependencies
anyhow = { workspace = true }
tokio = { workspace = true }
serde = { workspace = true }
log = { workspace = true }
clap = { workspace = true }
bytes = { workspace = true }
base64 = { workspace = true }
age = { workspace = true }
secrecy = { workspace = true }
ed25519-dalek = { workspace = true }
rand = { workspace = true }
# Local Crate Dependencies
libdbstorage = { path = "../libdbstorage" }
# We will create these libraries in the next steps
libcryptoa = { path = "../libcryptoa" }
[dev-dependencies]
redis = { version = "0.24", features = ["aio", "tokio-comp"] }

267
crates/herodb/README.md Normal file
View File

@@ -0,0 +1,267 @@
# Build Your Own Redis in Rust
This project is to build a toy Redis-Server clone that's capable of parsing Redis protocol and handling basic Redis commands, parsing and initializing Redis from RDB file,
supporting leader-follower replication, redis streams (queue), redis batch commands in transaction.
You can find all the source code and commit history in [my github repo](https://github.com/fangpin/redis-rs).
## Main features
+ Parse Redis protocol
+ Handle basic Redis commands
+ Parse and initialize Redis from RDB file
+ Leader-follower Replication
## Prerequisites
install `redis-cli` first (an implementation of redis client for test purpose)
```sh
cargo install mini-redis
```
Learn about:
- [Redis protocoal](https://redis.io/docs/latest/develop/reference/protocol-spec)
- [RDB file format](https://rdb.fnordig.de/file_format.html)
- [Redis replication](https://redis.io/docs/management/replication/)
## Start the Redis-rs server
```sh
# start as master
cargo run -- --dir /some/db/path --dbfilename dump.rdb
# start as slave
cargo run -- --dir /some/db/path --dbfilename dump.rdb --port 6380 --replicaof "localhost 6379"
```
## Supported Commands
```sh
# basic commands
redis-cli PING
redis-cli ECHO hey
redis-cli SET foo bar
redis-cli SET foo bar px/ex 100
redis-cli GET foo
redis-cli SET foo 2
redis-cli INCR foo
redis-cli INCR missing_key
redis-cli TYPE some_key
redis-cli KEYS "*"
# leader-follower replication related commands
redis-cli CONFIG GET dbfilename
redis-cli INFO replication
# streams related commands
redis-cli XADD stream_key 1526919030474-0 temperature 36 humidity 95
redis-cli XADD stream_key 1526919030474-* temperature 37 humidity 94
redis-cli XADD stream_key "*" foo bar
## read stream
redis-cli XRANGE stream_key 0-2 0-3
## query with + -
redis-cli XRANGE some_key - 1526985054079
## query single stream using xread
redis-cli XREAD streams some_key 1526985054069-0
## query multiple stream using xread
redis-cli XREAD streams stream_key other_stream_key 0-0 0-1
## blocking reads without timeout
redis-cli XREAD block 0 streams some_key 1526985054069-0
# transactions related commands
## start a transaction and exec all queued commands in a transaction
redis-cli
> MULTI
> set foo 1
> incr foo
> exec
## start a transaction and queued commands and cancel transaction then
redis-cli
> MULTI
> set foo 1
> incr foo
> discard
```
## RDB Persistence
Get Redis-rs server config
```sh
redis-cli CONFIG GET dbfilename
```
### RDB file format overview
Here are the different sections of the [RDB file](https://rdb.fnordig.de/file_format.html), in order:
+ Header section
+ Metadata section
+ Database section
+ End of file section
#### Header section
start with some magic number
```sh
52 45 44 49 53 30 30 31 31 // Magic string + version number (ASCII): "REDIS0011".
```
#### Metadata section
contains zero or more "metadata subsections", which each specify a single metadata attribute
e.g.
```sh
FA // Indicates the start of a metadata subsection.
09 72 65 64 69 73 2D 76 65 72 // The name of the metadata attribute (string encoded): "redis-ver".
06 36 2E 30 2E 31 36 // The value of the metadata attribute (string encoded): "6.0.16".
```
#### Database section
contains zero or more "database subsections," which each describe a single database.
e.g.
```sh
FE // Indicates the start of a database subsection.
00 /* The index of the database (size encoded). Here, the index is 0. */
FB // Indicates that hash table size information follows.
03 /* The size of the hash table that stores the keys and values (size encoded). Here, the total key-value hash table size is 3. */
02 /* The size of the hash table that stores the expires of the keys (size encoded). Here, the number of keys with an expiry is 2. */
```
```sh
00 /* The 1-byte flag that specifies the values type and encoding. Here, the flag is 0, which means "string." */
06 66 6F 6F 62 61 72 // The name of the key (string encoded). Here, it's "foobar".
06 62 61 7A 71 75 78 // The value (string encoded). Here, it's "bazqux".
```
```sh
FC /* Indicates that this key ("foo") has an expire, and that the expire timestamp is expressed in milliseconds. */
15 72 E7 07 8F 01 00 00 /* The expire timestamp, expressed in Unix time, stored as an 8-byte unsigned long, in little-endian (read right-to-left). Here, the expire timestamp is 1713824559637. */
00 // Value type is string.
03 66 6F 6F // Key name is "foo".
03 62 61 72 // Value is "bar".
```
```sh
FD /* Indicates that this key ("baz") has an expire, and that the expire timestamp is expressed in seconds. */
52 ED 2A 66 /* The expire timestamp, expressed in Unix time, stored as an 4-byte unsigned integer, in little-endian (read right-to-left). Here, the expire timestamp is 1714089298. */
00 // Value type is string.
03 62 61 7A // Key name is "baz".
03 71 75 78 // Value is "qux".
```
In summary,
- Optional expire information (one of the following):
- Timestamp in seconds:
- FD
- Expire timestamp in seconds (4-byte unsigned integer)
- Timestamp in milliseconds:
- FC
- Expire timestamp in milliseconds (8-byte unsigned long)
- Value type (1-byte flag)
- Key (string encoded)
- Value (encoding depends on value type)
#### End of file section
```sh
FF /* Indicates that the file is ending, and that the checksum follows. */
89 3b b7 4e f8 0f 77 19 // An 8-byte CRC64 checksum of the entire file.
```
#### Size encoding
```sh
/* If the first two bits are 0b00:
The size is the remaining 6 bits of the byte.
In this example, the size is 10: */
0A
00001010
/* If the first two bits are 0b01:
The size is the next 14 bits
(remaining 6 bits in the first byte, combined with the next byte),
in big-endian (read left-to-right).
In this example, the size is 700: */
42 BC
01000010 10111100
/* If the first two bits are 0b10:
Ignore the remaining 6 bits of the first byte.
The size is the next 4 bytes, in big-endian (read left-to-right).
In this example, the size is 17000: */
80 00 00 42 68
10000000 00000000 00000000 01000010 01101000
/* If the first two bits are 0b11:
The remaining 6 bits specify a type of string encoding.
See string encoding section. */
```
#### String encoding
+ The size of the string (size encoded).
+ The string.
```sh
/* The 0x0D size specifies that the string is 13 characters long. The remaining characters spell out "Hello, World!". */
0D 48 65 6C 6C 6F 2C 20 57 6F 72 6C 64 21
```
For sizes that begin with 0b11, the remaining 6 bits indicate a type of string format:
```sh
/* The 0xC0 size indicates the string is an 8-bit integer. In this example, the string is "123". */
C0 7B
/* The 0xC1 size indicates the string is a 16-bit integer. The remaining bytes are in little-endian (read right-to-left). In this example, the string is "12345". */
C1 39 30
/* The 0xC2 size indicates the string is a 32-bit integer. The remaining bytes are in little-endian (read right-to-left), In this example, the string is "1234567". */
C2 87 D6 12 00
/* The 0xC3 size indicates that the string is compressed with the LZF algorithm. You will not encounter LZF-compressed strings in this challenge. */
C3 ...
```
## Replication
Redis server [leader-follower replication](https://redis.io/docs/management/replication/).
Run multiple Redis servers with one acting as the "master" and the others as "replicas". Changes made to the master will be automatically replicated to replicas.
### Send Handshake (follower -> master)
1. When the follower starts, it will send a PING command to the master as RESP Array.
2. Then 2 REPLCONF (replication config) commands are sent to master from follower to communicate the port and the sync protocol. One is *REPLCONF listening-port <PORT>* and the other is *REPLCONF capa psync2*. psnync2 is an example sync protocol supported in this project.
3. The follower sends the *PSYNC* command to master with replication id and offset to start the replication process.
### Receive Handshake (master -> follower)
1. Response a PONG message to follower.
2. Response an OK message to follower for both REPLCONF commands.
3. Response a *+FULLRESYNC <REPL_ID> 0\r\n* to follower with the replication id and offset.
### RDB file transfer
When the follower starts, it sends a *PSYNC ? -1* command to tell master that it doesn't have any data yet, and needs a full resync.
Then the master send a *FULLRESYNC* response to the follower as an acknowledgement.
Finally, the master send the RDB file to represent its current state to the follower. The follower should load the RDB file received to the memory, replacing its current state.
### Receive write commands (master -> follower)
The master sends following write commands to the follower with the offset info.
The sending is to reuse the same TCP connection of handshake and RDB file transfer.
As the all the commands are encoded as RESP Array just like a normal client command, so the follower could reuse the same logic to handler the replicate commands from master. The only difference is the commands are coming from the master and no need response back.
## Streams
A stream is identified by a key, and it contains multiple entries.
Each entry consists of one or more key-value pairs, and is assigned a unique ID.
[More about redis streams](https://redis.io/docs/latest/develop/data-types/streams/)
[Radix tree](https://en.wikipedia.org/wiki/Radix_tree)
It looks like a list of key-value pairs.
```sh
entries:
- id: 1526985054069-0 # (ID of the first entry)
temperature: 36 # (A key value pair in the first entry)
humidity: 95 # (Another key value pair in the first entry)
- id: 1526985054079-0 # (ID of the second entry)
temperature: 37 # (A key value pair in the first entry)
humidity: 94 # (Another key value pair in the first entry)
# ... (and so on)
```
Examples of Redis stream use cases include:
- Event sourcing (e.g., tracking user actions, clicks, etc.)
- Sensor monitoring (e.g., readings from devices in the field)
- Notifications (e.g., storing a record of each user's notifications in a separate stream)
## Transaction
When *MULTI* command is called in a connection, redis just queued all following commands until *EXEC* or *DISCARD* command is called.
*EXEC* command will execute all queued commands and return an array representation of all execution result (including), instead the *DISCARD* command just clear all queued commands.
The transactions among each client connection are independent.

326
crates/herodb/age.rs Normal file
View File

@@ -0,0 +1,326 @@
//! age.rs — AGE (rage) helpers + persistent key management for your mini-Redis.
//
// Features:
// - X25519 encryption/decryption (age style)
// - Ed25519 detached signatures + verification
// - Persistent named keys in DB (strings):
// age:key:{name} -> X25519 recipient (public encryption key, "age1...")
// age:privkey:{name} -> X25519 identity (secret encryption key, "AGE-SECRET-KEY-1...")
// age:signpub:{name} -> Ed25519 verify pubkey (public, used to verify signatures)
// age:signpriv:{name} -> Ed25519 signing secret key (private, used to sign)
// - Base64 wrapping for ciphertext/signature binary blobs.
use crate::protocol::Protocol;
use crate::server::Server;
use libdbstorage::DBError;
use libcryptoa::AsymmetricCryptoError;
// ---------- Storage helpers ----------
fn sget(server: &Server, key: &str) -> Result<Option<String>, DBError> {
let st = server.current_storage()?;
st.get(key)
}
fn sset(server: &Server, key: &str, val: &str) -> Result<(), DBError> {
let st = server.current_storage()?;
st.set(key.to_string(), val.to_string())
}
fn enc_pub_key_key(name: &str) -> String { format!("age:key:{name}") }
fn enc_priv_key_key(name: &str) -> String { format!("age:privkey:{name}") }
fn sign_pub_key_key(name: &str) -> String { format!("age:signpub:{name}") }
fn sign_priv_key_key(name: &str) -> String { format!("age:signpriv:{name}") }
// ---------- Command handlers (RESP Protocol) ----------
// Basic (stateless) ones kept for completeness
pub async fn cmd_age_genenc() -> Protocol {
let (recip, ident) = libcryptoa::gen_enc_keypair();
Protocol::Array(vec![Protocol::BulkString(recip), Protocol::BulkString(ident)])
}
pub async fn cmd_age_gensign() -> Protocol {
let (verify, secret) = libcryptoa::gen_sign_keypair();
Protocol::Array(vec![Protocol::BulkString(verify), Protocol::BulkString(secret)])
}
pub async fn cmd_age_encrypt(recipient: &str, message: &str) -> Protocol {
match libcryptoa::encrypt_b64(recipient, message) {
Ok(b64) => Protocol::BulkString(b64),
Err(e) => Protocol::err(&format!("ERR age: {e}")),
}
}
pub async fn cmd_age_decrypt(identity: &str, ct_b64: &str) -> Protocol {
match libcryptoa::decrypt_b64(identity, ct_b64) {
Ok(pt) => Protocol::BulkString(pt),
Err(e) => Protocol::err(&format!("ERR age: {e}")),
}
}
pub async fn cmd_age_sign(secret: &str, message: &str) -> Protocol {
match libcryptoa::sign_b64(secret, message) {
Ok(b64sig) => Protocol::BulkString(b64sig),
Err(e) => Protocol::err(&format!("ERR age: {e}")),
}
}
pub async fn cmd_age_verify(verify_pub: &str, message: &str, sig_b64: &str) -> Protocol {
match libcryptoa::verify_b64(verify_pub, message, sig_b64) {
Ok(true) => Protocol::SimpleString("1".to_string()),
Ok(false) => Protocol::SimpleString("0".to_string()),
Err(e) => Protocol::err(&format!("ERR age: {e}")),
}
}
// ---------- NEW: Persistent, named-key commands ----------
pub async fn cmd_age_keygen(server: &Server, name: &str) -> Protocol {
let (recip, ident) = libcryptoa::gen_enc_keypair();
if let Err(e) = sset(server, &enc_pub_key_key(name), &recip) { return Protocol::err(&e.0); }
if let Err(e) = sset(server, &enc_priv_key_key(name), &ident) { return Protocol::err(&e.0); }
Protocol::Array(vec![Protocol::BulkString(recip), Protocol::BulkString(ident)])
}
pub async fn cmd_age_signkeygen(server: &Server, name: &str) -> Protocol {
let (verify, secret) = libcryptoa::gen_sign_keypair();
if let Err(e) = sset(server, &sign_pub_key_key(name), &verify) { return Protocol::err(&e.0); }
if let Err(e) = sset(server, &sign_priv_key_key(name), &secret) { return Protocol::err(&e.0); }
Protocol::Array(vec![Protocol::BulkString(verify), Protocol::BulkString(secret)])
}
pub async fn cmd_age_encrypt_name(server: &Server, name: &str, message: &str) -> Protocol {
let recip = match sget(server, &enc_pub_key_key(name)) {
Ok(Some(v)) => v,
Ok(None) => return Protocol::err(&format!("ERR age: missing recipient (age:key:{name})")),
Err(e) => return Protocol::err(&e.0),
};
match libcryptoa::encrypt_b64(&recip, message) {
Ok(ct) => Protocol::BulkString(ct),
Err(e) => Protocol::err(&format!("ERR age: {e}")),
}
}
pub async fn cmd_age_decrypt_name(server: &Server, name: &str, ct_b64: &str) -> Protocol {
let ident = match sget(server, &enc_priv_key_key(name)) {
Ok(Some(v)) => v,
Ok(None) => return Protocol::err(&format!("ERR age: missing identity (age:privkey:{name})")),
Err(e) => return Protocol::err(&e.0),
};
match libcryptoa::decrypt_b64(&ident, ct_b64) {
Ok(pt) => Protocol::BulkString(pt),
Err(e) => Protocol::err(&format!("ERR age: {e}")),
}
}
pub async fn cmd_age_sign_name(server: &Server, name: &str, message: &str) -> Protocol {
let sec = match sget(server, &sign_priv_key_key(name)) {
Ok(Some(v)) => v,
Ok(None) => return Protocol::err(&format!("ERR age: missing signing secret (age:signpriv:{name})")),
Err(e) => return Protocol::err(&e.0),
};
match libcryptoa::sign_b64(&sec, message) {
Ok(sig) => Protocol::BulkString(sig),
Err(e) => Protocol::err(&format!("ERR age: {e}")),
}
}
pub async fn cmd_age_verify_name(server: &Server, name: &str, message: &str, sig_b64: &str) -> Protocol {
let pubk = match sget(server, &sign_pub_key_key(name)) {
Ok(Some(v)) => v,
Ok(None) => return Protocol::err(&format!("ERR age: missing verify pubkey (age:signpub:{name})")),
Err(e) => return Protocol::err(&e.0),
};
match libcryptoa::verify_b64(&pubk, message, sig_b64) {
Ok(true) => Protocol::SimpleString("1".to_string()),
Ok(false) => Protocol::SimpleString("0".to_string()),
Err(e) => Protocol::err(&format!("ERR age: {e}")),
}
}
pub async fn cmd_age_list(server: &Server) -> Protocol {
// Returns 4 arrays: ["encpub", <names...>], ["encpriv", ...], ["signpub", ...], ["signpriv", ...]
let st = match server.current_storage() { Ok(s) => s, Err(e) => return Protocol::err(&e.0) };
let pull = |pat: &str, prefix: &str| -> Result<Vec<String>, DBError> {
let keys = st.keys(pat)?;
let mut names: Vec<String> = keys.into_iter()
.filter_map(|k| k.strip_prefix(prefix).map(|x| x.to_string()))
.collect();
names.sort();
Ok(names)
};
let encpub = match pull("age:key:*", "age:key:") { Ok(v) => v, Err(e)=> return Protocol::err(&e.0) };
let encpriv = match pull("age:privkey:*", "age:privkey:") { Ok(v) => v, Err(e)=> return Protocol::err(&e.0) };
let signpub = match pull("age:signpub:*", "age:signpub:") { Ok(v) => v, Err(e)=> return Protocol::err(&e.0) };
let signpriv= match pull("age:signpriv:*", "age:signpriv:") { Ok(v) => v, Err(e)=> return Protocol::err(&e.0) };
let to_arr = |label: &str, v: Vec<String>| {
let mut out = vec![Protocol::BulkString(label.to_string())];
out.push(Protocol::Array(v.into_iter().map(Protocol::BulkString).collect()));
Protocol::Array(out)
};
Protocol::Array(vec![
to_arr("encpub", encpub),
to_arr("encpriv", encpriv),
to_arr("signpub", signpub),
to_arr("signpriv", signpriv),
])
}
// ---------- Storage helpers ----------
fn sget(server: &Server, key: &str) -> Result<Option<String>, AgeWireError> {
let st = server.current_storage().map_err(|e| AgeWireError::Storage(e.0))?;
st.get(key).map_err(|e| AgeWireError::Storage(e.0))
}
fn sset(server: &Server, key: &str, val: &str) -> Result<(), AgeWireError> {
let st = server.current_storage().map_err(|e| AgeWireError::Storage(e.0))?;
st.set(key.to_string(), val.to_string()).map_err(|e| AgeWireError::Storage(e.0))
}
fn enc_pub_key_key(name: &str) -> String { format!("age:key:{name}") }
fn enc_priv_key_key(name: &str) -> String { format!("age:privkey:{name}") }
fn sign_pub_key_key(name: &str) -> String { format!("age:signpub:{name}") }
fn sign_priv_key_key(name: &str) -> String { format!("age:signpriv:{name}") }
// ---------- Command handlers (RESP Protocol) ----------
// Basic (stateless) ones kept for completeness
pub async fn cmd_age_genenc() -> Protocol {
let (recip, ident) = gen_enc_keypair();
Protocol::Array(vec![Protocol::BulkString(recip), Protocol::BulkString(ident)])
}
pub async fn cmd_age_gensign() -> Protocol {
let (verify, secret) = gen_sign_keypair();
Protocol::Array(vec![Protocol::BulkString(verify), Protocol::BulkString(secret)])
}
pub async fn cmd_age_encrypt(recipient: &str, message: &str) -> Protocol {
match encrypt_b64(recipient, message) {
Ok(b64) => Protocol::BulkString(b64),
Err(e) => e.to_protocol(),
}
}
pub async fn cmd_age_decrypt(identity: &str, ct_b64: &str) -> Protocol {
match decrypt_b64(identity, ct_b64) {
Ok(pt) => Protocol::BulkString(pt),
Err(e) => e.to_protocol(),
}
}
pub async fn cmd_age_sign(secret: &str, message: &str) -> Protocol {
match sign_b64(secret, message) {
Ok(b64sig) => Protocol::BulkString(b64sig),
Err(e) => e.to_protocol(),
}
}
pub async fn cmd_age_verify(verify_pub: &str, message: &str, sig_b64: &str) -> Protocol {
match verify_b64(verify_pub, message, sig_b64) {
Ok(true) => Protocol::SimpleString("1".to_string()),
Ok(false) => Protocol::SimpleString("0".to_string()),
Err(e) => e.to_protocol(),
}
}
// ---------- NEW: Persistent, named-key commands ----------
pub async fn cmd_age_keygen(server: &Server, name: &str) -> Protocol {
let (recip, ident) = gen_enc_keypair();
if let Err(e) = sset(server, &enc_pub_key_key(name), &recip) { return e.to_protocol(); }
if let Err(e) = sset(server, &enc_priv_key_key(name), &ident) { return e.to_protocol(); }
Protocol::Array(vec![Protocol::BulkString(recip), Protocol::BulkString(ident)])
}
pub async fn cmd_age_signkeygen(server: &Server, name: &str) -> Protocol {
let (verify, secret) = gen_sign_keypair();
if let Err(e) = sset(server, &sign_pub_key_key(name), &verify) { return e.to_protocol(); }
if let Err(e) = sset(server, &sign_priv_key_key(name), &secret) { return e.to_protocol(); }
Protocol::Array(vec![Protocol::BulkString(verify), Protocol::BulkString(secret)])
}
pub async fn cmd_age_encrypt_name(server: &Server, name: &str, message: &str) -> Protocol {
let recip = match sget(server, &enc_pub_key_key(name)) {
Ok(Some(v)) => v,
Ok(None) => return AgeWireError::NotFound("recipient (age:key:{name})").to_protocol(),
Err(e) => return e.to_protocol(),
};
match encrypt_b64(&recip, message) {
Ok(ct) => Protocol::BulkString(ct),
Err(e) => e.to_protocol(),
}
}
pub async fn cmd_age_decrypt_name(server: &Server, name: &str, ct_b64: &str) -> Protocol {
let ident = match sget(server, &enc_priv_key_key(name)) {
Ok(Some(v)) => v,
Ok(None) => return AgeWireError::NotFound("identity (age:privkey:{name})").to_protocol(),
Err(e) => return e.to_protocol(),
};
match decrypt_b64(&ident, ct_b64) {
Ok(pt) => Protocol::BulkString(pt),
Err(e) => e.to_protocol(),
}
}
pub async fn cmd_age_sign_name(server: &Server, name: &str, message: &str) -> Protocol {
let sec = match sget(server, &sign_priv_key_key(name)) {
Ok(Some(v)) => v,
Ok(None) => return AgeWireError::NotFound("signing secret (age:signpriv:{name})").to_protocol(),
Err(e) => return e.to_protocol(),
};
match sign_b64(&sec, message) {
Ok(sig) => Protocol::BulkString(sig),
Err(e) => e.to_protocol(),
}
}
pub async fn cmd_age_verify_name(server: &Server, name: &str, message: &str, sig_b64: &str) -> Protocol {
let pubk = match sget(server, &sign_pub_key_key(name)) {
Ok(Some(v)) => v,
Ok(None) => return AgeWireError::NotFound("verify pubkey (age:signpub:{name})").to_protocol(),
Err(e) => return e.to_protocol(),
};
match verify_b64(&pubk, message, sig_b64) {
Ok(true) => Protocol::SimpleString("1".to_string()),
Ok(false) => Protocol::SimpleString("0".to_string()),
Err(e) => e.to_protocol(),
}
}
pub async fn cmd_age_list(server: &Server) -> Protocol {
// Returns 4 arrays: ["encpub", <names...>], ["encpriv", ...], ["signpub", ...], ["signpriv", ...]
let st = match server.current_storage() { Ok(s) => s, Err(e) => return Protocol::err(&e.0) };
let pull = |pat: &str, prefix: &str| -> Result<Vec<String>, DBError> {
let keys = st.keys(pat)?;
let mut names: Vec<String> = keys.into_iter()
.filter_map(|k| k.strip_prefix(prefix).map(|x| x.to_string()))
.collect();
names.sort();
Ok(names)
};
let encpub = match pull("age:key:*", "age:key:") { Ok(v) => v, Err(e)=> return Protocol::err(&e.0) };
let encpriv = match pull("age:privkey:*", "age:privkey:") { Ok(v) => v, Err(e)=> return Protocol::err(&e.0) };
let signpub = match pull("age:signpub:*", "age:signpub:") { Ok(v) => v, Err(e)=> return Protocol::err(&e.0) };
let signpriv= match pull("age:signpriv:*", "age:signpriv:") { Ok(v) => v, Err(e)=> return Protocol::err(&e.0) };
let to_arr = |label: &str, v: Vec<String>| {
let mut out = vec![Protocol::BulkString(label.to_string())];
out.push(Protocol::Array(v.into_iter().map(Protocol::BulkString).collect()));
Protocol::Array(out)
};
Protocol::Array(vec![
to_arr("encpub", encpub),
to_arr("encpriv", encpriv),
to_arr("signpub", signpub),
to_arr("signpriv", signpriv),
])
}

973
crates/herodb/cmd.rs Normal file
View File

@@ -0,0 +1,973 @@
use crate::protocol::Protocol;
use crate::server::Server;
use libdbstorage::DBError;
use libcryptoa;
use serde::Serialize;
#[derive(Debug, Clone)]
pub enum Cmd {
Ping,
Echo(String),
Select(u64), // Changed from u16 to u64
Get(String),
Set(String, String),
SetPx(String, String, u128),
SetEx(String, String, u128),
Keys,
ConfigGet(String),
Info(Option<String>),
Del(String),
Type(String),
Incr(String),
Multi,
Exec,
Discard,
// Hash commands
HSet(String, Vec<(String, String)>),
HGet(String, String),
HGetAll(String),
HDel(String, Vec<String>),
HExists(String, String),
HKeys(String),
HVals(String),
HLen(String),
HMGet(String, Vec<String>),
HSetNx(String, String, String),
HScan(String, u64, Option<String>, Option<u64>), // key, cursor, pattern, count
Scan(u64, Option<String>, Option<u64>), // cursor, pattern, count
Ttl(String),
Exists(String),
Quit,
Client(Vec<String>),
ClientSetName(String),
ClientGetName,
// List commands
LPush(String, Vec<String>),
RPush(String, Vec<String>),
LPop(String, Option<u64>),
RPop(String, Option<u64>),
LLen(String),
LRem(String, i64, String),
LTrim(String, i64, i64),
LIndex(String, i64),
LRange(String, i64, i64),
FlushDb,
Unknow(String),
// AGE (rage) commands — stateless
AgeGenEnc,
AgeGenSign,
AgeEncrypt(String, String), // recipient, message
AgeDecrypt(String, String), // identity, ciphertext_b64
AgeSign(String, String), // signing_secret, message
AgeVerify(String, String, String), // verify_pub, message, signature_b64
// NEW: persistent named-key commands
AgeKeygen(String), // name
AgeSignKeygen(String), // name
AgeEncryptName(String, String), // name, message
AgeDecryptName(String, String), // name, ciphertext_b64
AgeSignName(String, String), // name, message
AgeVerifyName(String, String, String), // name, message, signature_b64
AgeList,
}
impl Cmd {
pub fn from(s: &str) -> Result<(Self, Protocol, &str), DBError> {
let (protocol, remaining) = Protocol::from(s)?;
match protocol.clone() {
Protocol::Array(p) => {
let cmd = p.into_iter().map(|x| x.decode()).collect::<Vec<_>>();
if cmd.is_empty() {
return Err(DBError("cmd length is 0".to_string()));
}
Ok((
match cmd[0].to_lowercase().as_str() {
"select" => {
if cmd.len() != 2 {
return Err(DBError("wrong number of arguments for SELECT".to_string()));
}
let idx = cmd[1].parse::<u64>().map_err(|_| DBError("ERR DB index is not an integer".to_string()))?;
Cmd::Select(idx)
}
"echo" => Cmd::Echo(cmd[1].clone()),
"ping" => Cmd::Ping,
"get" => Cmd::Get(cmd[1].clone()),
"set" => {
if cmd.len() == 5 && cmd[3].to_lowercase() == "px" {
Cmd::SetPx(cmd[1].clone(), cmd[2].clone(), cmd[4].parse().unwrap())
} else if cmd.len() == 5 && cmd[3].to_lowercase() == "ex" {
Cmd::SetEx(cmd[1].clone(), cmd[2].clone(), cmd[4].parse().unwrap())
} else if cmd.len() == 3 {
Cmd::Set(cmd[1].clone(), cmd[2].clone())
} else {
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
}
}
"setex" => {
if cmd.len() != 4 {
return Err(DBError(format!("wrong number of arguments for SETEX command")));
}
Cmd::SetEx(cmd[1].clone(), cmd[3].clone(), cmd[2].parse().unwrap())
}
"config" => {
if cmd.len() != 3 || cmd[1].to_lowercase() != "get" {
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
} else {
Cmd::ConfigGet(cmd[2].clone())
}
}
"keys" => {
if cmd.len() != 2 || cmd[1] != "*" {
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
} else {
Cmd::Keys
}
}
"info" => {
let section = if cmd.len() == 2 {
Some(cmd[1].clone())
} else {
None
};
Cmd::Info(section)
}
"del" => {
if cmd.len() != 2 {
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
}
Cmd::Del(cmd[1].clone())
}
"type" => {
if cmd.len() != 2 {
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
}
Cmd::Type(cmd[1].clone())
}
"incr" => {
if cmd.len() != 2 {
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
}
Cmd::Incr(cmd[1].clone())
}
"multi" => {
if cmd.len() != 1 {
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
}
Cmd::Multi
}
"exec" => {
if cmd.len() != 1 {
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
}
Cmd::Exec
}
"discard" => Cmd::Discard,
// Hash commands
"hset" => {
if cmd.len() < 4 || (cmd.len() - 2) % 2 != 0 {
return Err(DBError(format!("wrong number of arguments for HSET command")));
}
let mut pairs = Vec::new();
let mut i = 2;
while i + 1 < cmd.len() {
pairs.push((cmd[i].clone(), cmd[i + 1].clone()));
i += 2;
}
Cmd::HSet(cmd[1].clone(), pairs)
}
"hget" => {
if cmd.len() != 3 {
return Err(DBError(format!("wrong number of arguments for HGET command")));
}
Cmd::HGet(cmd[1].clone(), cmd[2].clone())
}
"hgetall" => {
if cmd.len() != 2 {
return Err(DBError(format!("wrong number of arguments for HGETALL command")));
}
Cmd::HGetAll(cmd[1].clone())
}
"hdel" => {
if cmd.len() < 3 {
return Err(DBError(format!("wrong number of arguments for HDEL command")));
}
Cmd::HDel(cmd[1].clone(), cmd[2..].to_vec())
}
"hexists" => {
if cmd.len() != 3 {
return Err(DBError(format!("wrong number of arguments for HEXISTS command")));
}
Cmd::HExists(cmd[1].clone(), cmd[2].clone())
}
"hkeys" => {
if cmd.len() != 2 {
return Err(DBError(format!("wrong number of arguments for HKEYS command")));
}
Cmd::HKeys(cmd[1].clone())
}
"hvals" => {
if cmd.len() != 2 {
return Err(DBError(format!("wrong number of arguments for HVALS command")));
}
Cmd::HVals(cmd[1].clone())
}
"hlen" => {
if cmd.len() != 2 {
return Err(DBError(format!("wrong number of arguments for HLEN command")));
}
Cmd::HLen(cmd[1].clone())
}
"hmget" => {
if cmd.len() < 3 {
return Err(DBError(format!("wrong number of arguments for HMGET command")));
}
Cmd::HMGet(cmd[1].clone(), cmd[2..].to_vec())
}
"hsetnx" => {
if cmd.len() != 4 {
return Err(DBError(format!("wrong number of arguments for HSETNX command")));
}
Cmd::HSetNx(cmd[1].clone(), cmd[2].clone(), cmd[3].clone())
}
"hscan" => {
if cmd.len() < 3 {
return Err(DBError(format!("wrong number of arguments for HSCAN command")));
}
let key = cmd[1].clone();
let cursor = cmd[2].parse::<u64>().map_err(|_|
DBError("ERR invalid cursor".to_string()))?;
let mut pattern = None;
let mut count = None;
let mut i = 3;
while i < cmd.len() {
match cmd[i].to_lowercase().as_str() {
"match" => {
if i + 1 >= cmd.len() {
return Err(DBError("ERR syntax error".to_string()));
}
pattern = Some(cmd[i + 1].clone());
i += 2;
}
"count" => {
if i + 1 >= cmd.len() {
return Err(DBError("ERR syntax error".to_string()));
}
count = Some(cmd[i + 1].parse::<u64>().map_err(|_|
DBError("ERR value is not an integer or out of range".to_string()))?);
i += 2;
}
_ => {
return Err(DBError(format!("ERR syntax error")));
}
}
}
Cmd::HScan(key, cursor, pattern, count)
}
"scan" => {
if cmd.len() < 2 {
return Err(DBError(format!("wrong number of arguments for SCAN command")));
}
let cursor = cmd[1].parse::<u64>().map_err(|_|
DBError("ERR invalid cursor".to_string()))?;
let mut pattern = None;
let mut count = None;
let mut i = 2;
while i < cmd.len() {
match cmd[i].to_lowercase().as_str() {
"match" => {
if i + 1 >= cmd.len() {
return Err(DBError("ERR syntax error".to_string()));
}
pattern = Some(cmd[i + 1].clone());
i += 2;
}
"count" => {
if i + 1 >= cmd.len() {
return Err(DBError("ERR syntax error".to_string()));
}
count = Some(cmd[i + 1].parse::<u64>().map_err(|_|
DBError("ERR value is not an integer or out of range".to_string()))?);
i += 2;
}
_ => {
return Err(DBError(format!("ERR syntax error")));
}
}
}
Cmd::Scan(cursor, pattern, count)
}
"ttl" => {
if cmd.len() != 2 {
return Err(DBError(format!("wrong number of arguments for TTL command")));
}
Cmd::Ttl(cmd[1].clone())
}
"exists" => {
if cmd.len() != 2 {
return Err(DBError(format!("wrong number of arguments for EXISTS command")));
}
Cmd::Exists(cmd[1].clone())
}
"quit" => {
if cmd.len() != 1 {
return Err(DBError(format!("wrong number of arguments for QUIT command")));
}
Cmd::Quit
}
"client" => {
if cmd.len() > 1 {
match cmd[1].to_lowercase().as_str() {
"setname" => {
if cmd.len() == 3 {
Cmd::ClientSetName(cmd[2].clone())
} else {
return Err(DBError("wrong number of arguments for 'client setname' command".to_string()));
}
}
"getname" => {
if cmd.len() == 2 {
Cmd::ClientGetName
} else {
return Err(DBError("wrong number of arguments for 'client getname' command".to_string()));
}
}
_ => Cmd::Client(cmd[1..].to_vec()),
}
} else {
Cmd::Client(vec![])
}
}
"lpush" => {
if cmd.len() < 3 {
return Err(DBError(format!("wrong number of arguments for LPUSH command")));
}
Cmd::LPush(cmd[1].clone(), cmd[2..].to_vec())
}
"rpush" => {
if cmd.len() < 3 {
return Err(DBError(format!("wrong number of arguments for RPUSH command")));
}
Cmd::RPush(cmd[1].clone(), cmd[2..].to_vec())
}
"lpop" => {
if cmd.len() < 2 || cmd.len() > 3 {
return Err(DBError(format!("wrong number of arguments for LPOP command")));
}
let count = if cmd.len() == 3 {
Some(cmd[2].parse::<u64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?)
} else {
None
};
Cmd::LPop(cmd[1].clone(), count)
}
"rpop" => {
if cmd.len() < 2 || cmd.len() > 3 {
return Err(DBError(format!("wrong number of arguments for RPOP command")));
}
let count = if cmd.len() == 3 {
Some(cmd[2].parse::<u64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?)
} else {
None
};
Cmd::RPop(cmd[1].clone(), count)
}
"llen" => {
if cmd.len() != 2 {
return Err(DBError(format!("wrong number of arguments for LLEN command")));
}
Cmd::LLen(cmd[1].clone())
}
"lrem" => {
if cmd.len() != 4 {
return Err(DBError(format!("wrong number of arguments for LREM command")));
}
let count = cmd[2].parse::<i64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?;
Cmd::LRem(cmd[1].clone(), count, cmd[3].clone())
}
"ltrim" => {
if cmd.len() != 4 {
return Err(DBError(format!("wrong number of arguments for LTRIM command")));
}
let start = cmd[2].parse::<i64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?;
let stop = cmd[3].parse::<i64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?;
Cmd::LTrim(cmd[1].clone(), start, stop)
}
"lindex" => {
if cmd.len() != 3 {
return Err(DBError(format!("wrong number of arguments for LINDEX command")));
}
let index = cmd[2].parse::<i64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?;
Cmd::LIndex(cmd[1].clone(), index)
}
"lrange" => {
if cmd.len() != 4 {
return Err(DBError(format!("wrong number of arguments for LRANGE command")));
}
let start = cmd[2].parse::<i64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?;
let stop = cmd[3].parse::<i64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?;
Cmd::LRange(cmd[1].clone(), start, stop)
}
"flushdb" => {
if cmd.len() != 1 {
return Err(DBError("wrong number of arguments for FLUSHDB command".to_string()));
}
Cmd::FlushDb
}
"age" => {
if cmd.len() < 2 {
return Err(DBError("wrong number of arguments for AGE".to_string()));
}
match cmd[1].to_lowercase().as_str() {
// stateless
"genenc" => { if cmd.len() != 2 { return Err(DBError("AGE GENENC takes no args".to_string())); }
Cmd::AgeGenEnc }
"gensign" => { if cmd.len() != 2 { return Err(DBError("AGE GENSIGN takes no args".to_string())); }
Cmd::AgeGenSign }
"encrypt" => { if cmd.len() != 4 { return Err(DBError("AGE ENCRYPT <recipient> <message>".to_string())); }
Cmd::AgeEncrypt(cmd[2].clone(), cmd[3].clone()) }
"decrypt" => { if cmd.len() != 4 { return Err(DBError("AGE DECRYPT <identity> <ciphertext_b64>".to_string())); }
Cmd::AgeDecrypt(cmd[2].clone(), cmd[3].clone()) }
"sign" => { if cmd.len() != 4 { return Err(DBError("AGE SIGN <signing_secret> <message>".to_string())); }
Cmd::AgeSign(cmd[2].clone(), cmd[3].clone()) }
"verify" => { if cmd.len() != 5 { return Err(DBError("AGE VERIFY <verify_pub> <message> <signature_b64>".to_string())); }
Cmd::AgeVerify(cmd[2].clone(), cmd[3].clone(), cmd[4].clone()) }
// persistent names
"keygen" => { if cmd.len() != 3 { return Err(DBError("AGE KEYGEN <name>".to_string())); }
Cmd::AgeKeygen(cmd[2].clone()) }
"signkeygen" => { if cmd.len() != 3 { return Err(DBError("AGE SIGNKEYGEN <name>".to_string())); }
Cmd::AgeSignKeygen(cmd[2].clone()) }
"encryptname" => { if cmd.len() != 4 { return Err(DBError("AGE ENCRYPTNAME <name> <message>".to_string())); }
Cmd::AgeEncryptName(cmd[2].clone(), cmd[3].clone()) }
"decryptname" => { if cmd.len() != 4 { return Err(DBError("AGE DECRYPTNAME <name> <ciphertext_b64>".to_string())); }
Cmd::AgeDecryptName(cmd[2].clone(), cmd[3].clone()) }
"signname" => { if cmd.len() != 4 { return Err(DBError("AGE SIGNNAME <name> <message>".to_string())); }
Cmd::AgeSignName(cmd[2].clone(), cmd[3].clone()) }
"verifyname" => { if cmd.len() != 5 { return Err(DBError("AGE VERIFYNAME <name> <message> <signature_b64>".to_string())); }
Cmd::AgeVerifyName(cmd[2].clone(), cmd[3].clone(), cmd[4].clone()) }
"list" => { if cmd.len() != 2 { return Err(DBError("AGE LIST".to_string())); }
Cmd::AgeList }
_ => return Err(DBError(format!("unsupported AGE subcommand {:?}", cmd))),
}
}
_ => Cmd::Unknow(cmd[0].clone()),
},
protocol,
remaining
))
}
_ => Err(DBError(format!(
"fail to parse as cmd for {:?}",
protocol
))),
}
}
pub async fn run(self, server: &mut Server) -> Result<Protocol, DBError> {
// Handle queued commands for transactions
if server.queued_cmd.is_some()
&& !matches!(self, Cmd::Exec)
&& !matches!(self, Cmd::Multi)
&& !matches!(self, Cmd::Discard)
{
let protocol = self.clone().to_protocol();
server.queued_cmd.as_mut().unwrap().push((self, protocol));
return Ok(Protocol::SimpleString("QUEUED".to_string()));
}
match self {
Cmd::Select(db) => select_cmd(server, db).await,
Cmd::Ping => Ok(Protocol::SimpleString("PONG".to_string())),
Cmd::Echo(s) => Ok(Protocol::BulkString(s)),
Cmd::Get(k) => get_cmd(server, &k).await,
Cmd::Set(k, v) => set_cmd(server, &k, &v).await,
Cmd::SetPx(k, v, x) => set_px_cmd(server, &k, &v, &x).await,
Cmd::SetEx(k, v, x) => set_ex_cmd(server, &k, &v, &x).await,
Cmd::Del(k) => del_cmd(server, &k).await,
Cmd::ConfigGet(name) => config_get_cmd(&name, server),
Cmd::Keys => keys_cmd(server).await,
Cmd::Info(section) => info_cmd(server, &section).await,
Cmd::Type(k) => type_cmd(server, &k).await,
Cmd::Incr(key) => incr_cmd(server, &key).await,
Cmd::Multi => {
server.queued_cmd = Some(Vec::<(Cmd, Protocol)>::new());
Ok(Protocol::SimpleString("OK".to_string()))
}
Cmd::Exec => exec_cmd(server).await,
Cmd::Discard => {
if server.queued_cmd.is_some() {
server.queued_cmd = None;
Ok(Protocol::SimpleString("OK".to_string()))
} else {
Ok(Protocol::err("ERR DISCARD without MULTI"))
}
}
// Hash commands
Cmd::HSet(key, pairs) => hset_cmd(server, &key, &pairs).await,
Cmd::HGet(key, field) => hget_cmd(server, &key, &field).await,
Cmd::HGetAll(key) => hgetall_cmd(server, &key).await,
Cmd::HDel(key, fields) => hdel_cmd(server, &key, &fields).await,
Cmd::HExists(key, field) => hexists_cmd(server, &key, &field).await,
Cmd::HKeys(key) => hkeys_cmd(server, &key).await,
Cmd::HVals(key) => hvals_cmd(server, &key).await,
Cmd::HLen(key) => hlen_cmd(server, &key).await,
Cmd::HMGet(key, fields) => hmget_cmd(server, &key, &fields).await,
Cmd::HSetNx(key, field, value) => hsetnx_cmd(server, &key, &field, &value).await,
Cmd::HScan(key, cursor, pattern, count) => hscan_cmd(server, &key, &cursor, pattern.as_deref(), &count).await,
Cmd::Scan(cursor, pattern, count) => scan_cmd(server, &cursor, pattern.as_deref(), &count).await,
Cmd::Ttl(key) => ttl_cmd(server, &key).await,
Cmd::Exists(key) => exists_cmd(server, &key).await,
Cmd::Quit => Ok(Protocol::SimpleString("OK".to_string())),
Cmd::Client(_) => Ok(Protocol::SimpleString("OK".to_string())),
Cmd::ClientSetName(name) => client_setname_cmd(server, &name).await,
Cmd::ClientGetName => client_getname_cmd(server).await,
// List commands
Cmd::LPush(key, elements) => lpush_cmd(server, &key, &elements).await,
Cmd::RPush(key, elements) => rpush_cmd(server, &key, &elements).await,
Cmd::LPop(key, count) => lpop_cmd(server, &key, &count).await,
Cmd::RPop(key, count) => rpop_cmd(server, &key, &count).await,
Cmd::LLen(key) => llen_cmd(server, &key).await,
Cmd::LRem(key, count, element) => lrem_cmd(server, &key, count, &element).await,
Cmd::LTrim(key, start, stop) => ltrim_cmd(server, &key, start, stop).await,
Cmd::LIndex(key, index) => lindex_cmd(server, &key, index).await,
Cmd::LRange(key, start, stop) => lrange_cmd(server, &key, start, stop).await,
Cmd::FlushDb => flushdb_cmd(server).await,
// AGE (rage): stateless
Cmd::AgeGenEnc => Ok(libcryptoa::gen_enc_keypair().await),
Cmd::AgeGenSign => Ok(libcryptoa::gen_sign_keypair().await),
Cmd::AgeEncrypt(recipient, message) => Ok(libcryptoa::encrypt_b64(&recipient, &message).await),
Cmd::AgeDecrypt(identity, ct_b64) => Ok(libcryptoa::decrypt_b64(&identity, &ct_b64).await),
Cmd::AgeSign(secret, message) => Ok(libcryptoa::sign_b64(&secret, &message).await),
Cmd::AgeVerify(vpub, msg, sig_b64) => Ok(libcryptoa::verify_b64(&vpub, &msg, &sig_b64).await),
// AGE (rage): persistent named keys
Cmd::AgeKeygen(name) => Ok(crate::age::cmd_age_keygen(server, &name).await),
Cmd::AgeSignKeygen(name) => Ok(crate::age::cmd_age_signkeygen(server, &name).await),
Cmd::AgeEncryptName(name, message) => Ok(crate::age::cmd_age_encrypt_name(server, &name, &message).await),
Cmd::AgeDecryptName(name, ct_b64) => Ok(crate::age::cmd_age_decrypt_name(server, &name, &ct_b64).await),
Cmd::AgeSignName(name, message) => Ok(crate::age::cmd_age_sign_name(server, &name, &message).await),
Cmd::AgeVerifyName(name, message, sig_b64) => Ok(crate::age::cmd_age_verify_name(server, &name, &message, &sig_b64).await),
Cmd::AgeList => Ok(crate::age::cmd_age_list(server).await),
Cmd::Unknow(s) => Ok(Protocol::err(&format!("ERR unknown command `{}`", s))),
}
}
pub fn to_protocol(self) -> Protocol {
match self {
Cmd::Select(db) => Protocol::Array(vec![Protocol::BulkString("select".to_string()), Protocol::BulkString(db.to_string())]),
Cmd::Ping => Protocol::Array(vec![Protocol::BulkString("ping".to_string())]),
Cmd::Echo(s) => Protocol::Array(vec![Protocol::BulkString("echo".to_string()), Protocol::BulkString(s)]),
Cmd::Get(k) => Protocol::Array(vec![Protocol::BulkString("get".to_string()), Protocol::BulkString(k)]),
Cmd::Set(k, v) => Protocol::Array(vec![Protocol::BulkString("set".to_string()), Protocol::BulkString(k), Protocol::BulkString(v)]),
_ => Protocol::SimpleString("...".to_string())
}
}
}
async fn flushdb_cmd(server: &mut Server) -> Result<Protocol, DBError> {
match server.current_storage()?.flushdb() {
Ok(_) => Ok(Protocol::SimpleString("OK".to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn select_cmd(server: &mut Server, db: u64) -> Result<Protocol, DBError> {
// Test if we can access the database (this will create it if needed)
server.selected_db = db;
match server.current_storage() {
Ok(_) => Ok(Protocol::SimpleString("OK".to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn lindex_cmd(server: &Server, key: &str, index: i64) -> Result<Protocol, DBError> {
match server.current_storage()?.lindex(key, index) {
Ok(Some(element)) => Ok(Protocol::BulkString(element)),
Ok(None) => Ok(Protocol::Null),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn lrange_cmd(server: &Server, key: &str, start: i64, stop: i64) -> Result<Protocol, DBError> {
match server.current_storage()?.lrange(key, start, stop) {
Ok(elements) => Ok(Protocol::Array(elements.into_iter().map(Protocol::BulkString).collect())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn ltrim_cmd(server: &Server, key: &str, start: i64, stop: i64) -> Result<Protocol, DBError> {
match server.current_storage()?.ltrim(key, start, stop) {
Ok(_) => Ok(Protocol::SimpleString("OK".to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn lrem_cmd(server: &Server, key: &str, count: i64, element: &str) -> Result<Protocol, DBError> {
match server.current_storage()?.lrem(key, count, element) {
Ok(removed_count) => Ok(Protocol::SimpleString(removed_count.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn llen_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
match server.current_storage()?.llen(key) {
Ok(len) => Ok(Protocol::SimpleString(len.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn lpop_cmd(server: &Server, key: &str, count: &Option<u64>) -> Result<Protocol, DBError> {
let count_val = count.unwrap_or(1);
match server.current_storage()?.lpop(key, count_val) {
Ok(elements) => {
if elements.is_empty() {
if count.is_some() {
Ok(Protocol::Array(vec![]))
} else {
Ok(Protocol::Null)
}
} else if count.is_some() {
Ok(Protocol::Array(elements.into_iter().map(Protocol::BulkString).collect()))
} else {
Ok(Protocol::BulkString(elements[0].clone()))
}
},
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn rpop_cmd(server: &Server, key: &str, count: &Option<u64>) -> Result<Protocol, DBError> {
let count_val = count.unwrap_or(1);
match server.current_storage()?.rpop(key, count_val) {
Ok(elements) => {
if elements.is_empty() {
if count.is_some() {
Ok(Protocol::Array(vec![]))
} else {
Ok(Protocol::Null)
}
} else if count.is_some() {
Ok(Protocol::Array(elements.into_iter().map(Protocol::BulkString).collect()))
} else {
Ok(Protocol::BulkString(elements[0].clone()))
}
},
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn lpush_cmd(server: &Server, key: &str, elements: &[String]) -> Result<Protocol, DBError> {
match server.current_storage()?.lpush(key, elements.to_vec()) {
Ok(len) => Ok(Protocol::SimpleString(len.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn rpush_cmd(server: &Server, key: &str, elements: &[String]) -> Result<Protocol, DBError> {
match server.current_storage()?.rpush(key, elements.to_vec()) {
Ok(len) => Ok(Protocol::SimpleString(len.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn exec_cmd(server: &mut Server) -> Result<Protocol, DBError> {
// Move the queued commands out of `server` so we drop the borrow immediately.
let cmds = if let Some(cmds) = server.queued_cmd.take() {
cmds
} else {
return Ok(Protocol::err("ERR EXEC without MULTI"));
};
let mut out = Vec::new();
for (cmd, _) in cmds {
// Use Box::pin to handle recursion in async function
let res = Box::pin(cmd.run(server)).await?;
out.push(res);
}
Ok(Protocol::Array(out))
}
async fn incr_cmd(server: &Server, key: &String) -> Result<Protocol, DBError> {
let storage = server.current_storage()?;
let current_value = storage.get(key)?;
let new_value = match current_value {
Some(v) => {
match v.parse::<i64>() {
Ok(num) => num + 1,
Err(_) => return Ok(Protocol::err("ERR value is not an integer or out of range")),
}
}
None => 1,
};
storage.set(key.clone(), new_value.to_string())?;
Ok(Protocol::SimpleString(new_value.to_string()))
}
fn config_get_cmd(name: &String, server: &Server) -> Result<Protocol, DBError> {
let value = match name.as_str() {
"dir" => Some(server.option.dir.clone()),
"dbfilename" => Some(format!("{}.db", server.selected_db)),
"databases" => Some("16".to_string()), // Hardcoded as per original logic
_ => None,
};
if let Some(val) = value {
Ok(Protocol::Array(vec![
Protocol::BulkString(name.clone()),
Protocol::BulkString(val),
]))
} else {
// Return an empty array for unknown config options, which is standard Redis behavior
Ok(Protocol::Array(vec![]))
}
}
async fn keys_cmd(server: &Server) -> Result<Protocol, DBError> {
let keys = server.current_storage()?.keys("*")?;
Ok(Protocol::Array(
keys.into_iter().map(Protocol::BulkString).collect(),
))
}
#[derive(Serialize)]
struct ServerInfo {
redis_version: String,
encrypted: bool,
selected_db: u64,
}
async fn info_cmd(server: &Server, section: &Option<String>) -> Result<Protocol, DBError> {
let info = ServerInfo {
redis_version: "7.0.0".to_string(),
encrypted: server.current_storage()?.is_encrypted(),
selected_db: server.selected_db,
};
let mut info_string = String::new();
info_string.push_str(&format!("# Server\n"));
info_string.push_str(&format!("redis_version:{}\n", info.redis_version));
info_string.push_str(&format!("encrypted:{}\n", if info.encrypted { 1 } else { 0 }));
info_string.push_str(&format!("# Keyspace\n"));
info_string.push_str(&format!("db{}:keys=0,expires=0,avg_ttl=0\n", info.selected_db));
match section {
Some(s) => match s.as_str() {
"replication" => Ok(Protocol::BulkString(
"role:master\nmaster_replid:8371b4fb1155b71f4a04d3e1bc3e18c4a990aeea\nmaster_repl_offset:0\n".to_string()
)),
_ => Err(DBError(format!("unsupported section {:?}", s))),
},
None => {
Ok(Protocol::BulkString(info_string))
}
}
}
async fn type_cmd(server: &Server, k: &String) -> Result<Protocol, DBError> {
match server.current_storage()?.get_key_type(k)? {
Some(type_str) => Ok(Protocol::SimpleString(type_str)),
None => Ok(Protocol::SimpleString("none".to_string())),
}
}
async fn del_cmd(server: &Server, k: &str) -> Result<Protocol, DBError> {
server.current_storage()?.del(k.to_string())?;
Ok(Protocol::SimpleString("1".to_string()))
}
async fn set_ex_cmd(
server: &Server,
k: &str,
v: &str,
x: &u128,
) -> Result<Protocol, DBError> {
server.current_storage()?.setx(k.to_string(), v.to_string(), *x * 1000)?;
Ok(Protocol::SimpleString("OK".to_string()))
}
async fn set_px_cmd(
server: &Server,
k: &str,
v: &str,
x: &u128,
) -> Result<Protocol, DBError> {
server.current_storage()?.setx(k.to_string(), v.to_string(), *x)?;
Ok(Protocol::SimpleString("OK".to_string()))
}
async fn set_cmd(server: &Server, k: &str, v: &str) -> Result<Protocol, DBError> {
server.current_storage()?.set(k.to_string(), v.to_string())?;
Ok(Protocol::SimpleString("OK".to_string()))
}
async fn get_cmd(server: &Server, k: &str) -> Result<Protocol, DBError> {
let v = server.current_storage()?.get(k)?;
Ok(v.map_or(Protocol::Null, Protocol::BulkString))
}
// Hash command implementations
async fn hset_cmd(server: &Server, key: &str, pairs: &[(String, String)]) -> Result<Protocol, DBError> {
let new_fields = server.current_storage()?.hset(key, pairs.to_vec())?;
Ok(Protocol::SimpleString(new_fields.to_string()))
}
async fn hget_cmd(server: &Server, key: &str, field: &str) -> Result<Protocol, DBError> {
match server.current_storage()?.hget(key, field) {
Ok(Some(value)) => Ok(Protocol::BulkString(value)),
Ok(None) => Ok(Protocol::Null),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn hgetall_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
match server.current_storage()?.hgetall(key) {
Ok(pairs) => {
let mut result = Vec::new();
for (field, value) in pairs {
result.push(Protocol::BulkString(field));
result.push(Protocol::BulkString(value));
}
Ok(Protocol::Array(result))
}
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn hdel_cmd(server: &Server, key: &str, fields: &[String]) -> Result<Protocol, DBError> {
match server.current_storage()?.hdel(key, fields.to_vec()) {
Ok(deleted) => Ok(Protocol::SimpleString(deleted.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn hexists_cmd(server: &Server, key: &str, field: &str) -> Result<Protocol, DBError> {
match server.current_storage()?.hexists(key, field) {
Ok(exists) => Ok(Protocol::SimpleString(if exists { "1" } else { "0" }.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn hkeys_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
match server.current_storage()?.hkeys(key) {
Ok(keys) => Ok(Protocol::Array(
keys.into_iter().map(Protocol::BulkString).collect(),
)),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn hvals_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
match server.current_storage()?.hvals(key) {
Ok(values) => Ok(Protocol::Array(
values.into_iter().map(Protocol::BulkString).collect(),
)),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn hlen_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
match server.current_storage()?.hlen(key) {
Ok(len) => Ok(Protocol::SimpleString(len.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn hmget_cmd(server: &Server, key: &str, fields: &[String]) -> Result<Protocol, DBError> {
match server.current_storage()?.hmget(key, fields.to_vec()) {
Ok(values) => {
let result: Vec<Protocol> = values
.into_iter()
.map(|v| v.map_or(Protocol::Null, Protocol::BulkString))
.collect();
Ok(Protocol::Array(result))
}
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn hsetnx_cmd(server: &Server, key: &str, field: &str, value: &str) -> Result<Protocol, DBError> {
match server.current_storage()?.hsetnx(key, field, value) {
Ok(was_set) => Ok(Protocol::SimpleString(if was_set { "1" } else { "0" }.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn scan_cmd(
server: &Server,
cursor: &u64,
pattern: Option<&str>,
count: &Option<u64>
) -> Result<Protocol, DBError> {
match server.current_storage()?.scan(*cursor, pattern, *count) {
Ok((next_cursor, key_value_pairs)) => {
let mut result = Vec::new();
result.push(Protocol::BulkString(next_cursor.to_string()));
// For SCAN, we only return the keys, not the values
let keys: Vec<Protocol> = key_value_pairs.into_iter().map(|(key, _)| Protocol::BulkString(key)).collect();
result.push(Protocol::Array(keys));
Ok(Protocol::Array(result))
}
Err(e) => Ok(Protocol::err(&format!("ERR {}", e.0))),
}
}
async fn hscan_cmd(
server: &Server,
key: &str,
cursor: &u64,
pattern: Option<&str>,
count: &Option<u64>
) -> Result<Protocol, DBError> {
match server.current_storage()?.hscan(key, *cursor, pattern, *count) {
Ok((next_cursor, field_value_pairs)) => {
let mut result = Vec::new();
result.push(Protocol::BulkString(next_cursor.to_string()));
// For HSCAN, we return field-value pairs flattened
let mut fields_and_values = Vec::new();
for (field, value) in field_value_pairs {
fields_and_values.push(Protocol::BulkString(field));
fields_and_values.push(Protocol::BulkString(value));
}
result.push(Protocol::Array(fields_and_values));
Ok(Protocol::Array(result))
}
Err(e) => Ok(Protocol::err(&format!("ERR {}", e.0))),
}
}
async fn ttl_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
match server.current_storage()?.ttl(key) {
Ok(ttl) => Ok(Protocol::SimpleString(ttl.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn exists_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
match server.current_storage()?.exists(key) {
Ok(exists) => Ok(Protocol::SimpleString(if exists { "1" } else { "0" }.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn client_setname_cmd(server: &mut Server, name: &str) -> Result<Protocol, DBError> {
server.client_name = Some(name.to_string());
Ok(Protocol::SimpleString("OK".to_string()))
}
async fn client_getname_cmd(server: &Server) -> Result<Protocol, DBError> {
match &server.client_name {
Some(name) => Ok(Protocol::BulkString(name.clone())),
None => Ok(Protocol::Null),
}
}

View File

@@ -0,0 +1,71 @@
#!/bin/bash
# Start the herodb server in the background
echo "Starting herodb server..."
cargo run -p herodb -- --dir /tmp/herodb_age_test --port 6382 --debug --encryption-key "testkey" &
SERVER_PID=$!
sleep 2 # Give the server a moment to start
REDIS_CLI="redis-cli -p 6382"
echo "--- Generating and Storing Encryption Keys ---"
# The new AGE commands are 'AGE KEYGEN <name>' etc., based on src/cmd.rs
# This script uses older commands like 'AGE.GENERATE_KEYPAIR alice'
# The demo script needs to be updated to match the implemented commands.
# Let's assume the commands in the script are what's expected for now,
# but note this discrepancy. The new commands are AGE KEYGEN etc.
# The script here uses a different syntax not found in src/cmd.rs like 'AGE.GENERATE_KEYPAIR'.
# For now, I will modify the script to fit the actual implementation.
echo "--- Generating and Storing Encryption Keys ---"
$REDIS_CLI AGE KEYGEN alice
$REDIS_CLI AGE KEYGEN bob
echo "--- Encrypting and Decrypting a Message ---"
MESSAGE="Hello, AGE encryption!"
# The new logic stores keys internally and does not expose a command to get the public key.
# We will encrypt by name.
ALICE_PUBKEY_REPLY=$($REDIS_CLI AGE KEYGEN alice | head -n 2 | tail -n 1)
echo "Alice's Public Key: $ALICE_PUBKEY_REPLY"
echo "Encrypting message: '$MESSAGE' with Alice's identity..."
# AGE.ENCRYPT recipient message. But since we use persistent keys, let's use ENCRYPTNAME
CIPHERTEXT=$($REDIS_CLI AGE ENCRYPTNAME alice "$MESSAGE")
echo "Ciphertext: $CIPHERTEXT"
echo "Decrypting ciphertext with Alice's private key..."
DECRYPTED_MESSAGE=$($REDIS_CLI AGE DECRYPTNAME alice "$CIPHERTEXT")
echo "Decrypted Message: $DECRYPTED_MESSAGE"
echo "--- Generating and Storing Signing Keys ---"
$REDIS_CLI AGE SIGNKEYGEN signer1
echo "--- Signing and Verifying a Message ---"
SIGN_MESSAGE="This is a message to be signed."
# Similar to above, we don't have GET_SIGN_PUBKEY. We will verify by name.
echo "Signing message: '$SIGN_MESSAGE' with signer1's private key..."
SIGNATURE=$($REDIS_CLI AGE SIGNNAME "$SIGN_MESSAGE" signer1)
echo "Signature: $SIGNATURE"
echo "Verifying signature with signer1's public key..."
VERIFY_RESULT=$($REDIS_CLI AGE VERIFYNAME signer1 "$SIGN_MESSAGE" "$SIGNATURE")
echo "Verification Result: $VERIFY_RESULT"
# There is no DELETE_KEYPAIR command in the implementation
echo "--- Cleaning up keys (manual in herodb) ---"
# We would use DEL for age:key:alice, etc.
$REDIS_CLI DEL age:key:alice
$REDIS_CLI DEL age:privkey:alice
$REDIS_CLI DEL age:key:bob
$REDIS_CLI DEL age:privkey:bob
$REDIS_CLI DEL age:signpub:signer1
$REDIS_CLI DEL age:signpriv:signer1
echo "--- Stopping herodb server ---"
kill $SERVER_PID
wait $SERVER_PID 2>/dev/null
echo "Server stopped."
echo "Bash demo complete."

View File

@@ -0,0 +1,83 @@
use std::io::{Read, Write};
use std::net::TcpStream;
// Minimal RESP helpers
fn arr(parts: &[&str]) -> String {
let mut out = format!("*{}\r\n", parts.len());
for p in parts {
out.push_str(&format!("${}\r\n{}\r\n", p.len(), p));
}
out
}
fn read_reply(s: &mut TcpStream) -> String {
let mut buf = [0u8; 65536];
let n = s.read(&mut buf).unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}
fn parse_two_bulk(reply: &str) -> Option<(String,String)> {
let mut lines = reply.split("\r\n");
if lines.next()? != "*2" { return None; }
let _n = lines.next()?;
let a = lines.next()?.to_string();
let _m = lines.next()?;
let b = lines.next()?.to_string();
Some((a,b))
}
fn parse_bulk(reply: &str) -> Option<String> {
let mut lines = reply.split("\r\n");
let hdr = lines.next()?;
if !hdr.starts_with('$') { return None; }
Some(lines.next()?.to_string())
}
fn parse_simple(reply: &str) -> Option<String> {
let mut lines = reply.split("\r\n");
let hdr = lines.next()?;
if !hdr.starts_with('+') { return None; }
Some(hdr[1..].to_string())
}
fn main() {
let mut args = std::env::args().skip(1);
let host = args.next().unwrap_or_else(|| "127.0.0.1".into());
let port = args.next().unwrap_or_else(|| "6379".into());
let addr = format!("{host}:{port}");
println!("Connecting to {addr}...");
let mut s = TcpStream::connect(addr).expect("connect");
// Generate & persist X25519 enc keys under name "alice"
s.write_all(arr(&["age","keygen","alice"]).as_bytes()).unwrap();
let (_alice_recip, _alice_ident) = parse_two_bulk(&read_reply(&mut s)).expect("gen enc");
// Generate & persist Ed25519 signing key under name "signer"
s.write_all(arr(&["age","signkeygen","signer"]).as_bytes()).unwrap();
let (_verify, _secret) = parse_two_bulk(&read_reply(&mut s)).expect("gen sign");
// Encrypt by name
let msg = "hello from persistent keys";
s.write_all(arr(&["age","encryptname","alice", msg]).as_bytes()).unwrap();
let ct_b64 = parse_bulk(&read_reply(&mut s)).expect("ct b64");
println!("ciphertext b64: {}", ct_b64);
// Decrypt by name
s.write_all(arr(&["age","decryptname","alice", &ct_b64]).as_bytes()).unwrap();
let pt = parse_bulk(&read_reply(&mut s)).expect("pt");
assert_eq!(pt, msg);
println!("decrypted ok");
// Sign by name
s.write_all(arr(&["age","signname","signer", msg]).as_bytes()).unwrap();
let sig_b64 = parse_bulk(&read_reply(&mut s)).expect("sig b64");
// Verify by name
s.write_all(arr(&["age","verifyname","signer", msg, &sig_b64]).as_bytes()).unwrap();
let ok = parse_simple(&read_reply(&mut s)).expect("verify");
assert_eq!(ok, "1");
println!("signature verified");
// List names
s.write_all(arr(&["age","list"]).as_bytes()).unwrap();
let list = read_reply(&mut s);
println!("LIST -> {list}");
println!("✔ persistent AGE workflow complete.");
}

4
crates/herodb/lib.rs Normal file
View File

@@ -0,0 +1,4 @@
pub mod age; // NEW
pub mod cmd;
pub mod protocol;
pub mod server;

81
crates/herodb/main.rs Normal file
View File

@@ -0,0 +1,81 @@
// #![allow(unused_imports)]
use tokio::net::TcpListener;
use herodb::server;
use clap::Parser;
/// Simple program to greet a person
#[derive(Parser, Debug)]
#[command(version, about, long_about = None)]
struct Args {
/// The directory of Redis DB file
#[arg(long)]
dir: String,
/// The port of the Redis server, default is 6379 if not specified
#[arg(long)]
port: Option<u16>,
/// Enable debug mode
#[arg(long)]
debug: bool,
/// Master encryption key for encrypted databases
#[arg(long)]
encryption_key: Option<String>,
/// Encrypt the database
#[arg(long)]
encrypt: bool,
}
#[tokio::main]
async fn main() {
// parse args
let args = Args::parse();
// bind port
let port = args.port.unwrap_or(6379);
println!("will listen on port: {}", port);
let listener = TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
// new DB option
let option = herodb::options::DBOption {
dir: args.dir,
port,
debug: args.debug,
encryption_key: args.encryption_key,
encrypt: args.encrypt,
};
// new server
let server = server::Server::new(option).await;
// Add a small delay to ensure the port is ready
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
// accept new connections
loop {
let stream = listener.accept().await;
match stream {
Ok((stream, _)) => {
println!("accepted new connection");
let mut sc = server.clone();
tokio::spawn(async move {
if let Err(e) = sc.handle(stream).await {
println!("error: {:?}, will close the connection. Bye", e);
}
});
}
Err(e) => {
println!("error: {}", e);
}
}
}
}

8
crates/herodb/options.rs Normal file
View File

@@ -0,0 +1,8 @@
#[derive(Clone)]
pub struct DBOption {
pub dir: String,
pub port: u16,
pub debug: bool,
pub encrypt: bool,
pub encryption_key: Option<String>, // Master encryption key
}

159
crates/herodb/protocol.rs Normal file
View File

@@ -0,0 +1,159 @@
use core::fmt;
use crate::error::DBError;
#[derive(Debug, Clone)]
pub enum Protocol {
SimpleString(String),
BulkString(String),
Null,
Array(Vec<Protocol>),
Error(String), // NEW
}
impl fmt::Display for Protocol {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.decode().as_str())
}
}
impl Protocol {
pub fn from(protocol: &str) -> Result<(Self, &str), DBError> {
let ret = match protocol.chars().nth(0) {
Some('+') => Self::parse_simple_string_sfx(&protocol[1..]),
Some('$') => Self::parse_bulk_string_sfx(&protocol[1..]),
Some('*') => Self::parse_array_sfx(&protocol[1..]),
_ => Err(DBError(format!(
"[from] unsupported protocol: {:?}",
protocol
))),
};
ret
}
pub fn from_vec(array: Vec<&str>) -> Self {
let array = array
.into_iter()
.map(|x| Protocol::BulkString(x.to_string()))
.collect();
Protocol::Array(array)
}
#[inline]
pub fn ok() -> Self {
Protocol::SimpleString("ok".to_string())
}
#[inline]
pub fn err(msg: &str) -> Self {
Protocol::Error(msg.to_string())
}
#[inline]
pub fn write_on_slave_err() -> Self {
Self::err("DISALLOW WRITE ON SLAVE")
}
#[inline]
pub fn psync_on_slave_err() -> Self {
Self::err("PSYNC ON SLAVE IS NOT ALLOWED")
}
#[inline]
pub fn none() -> Self {
Self::SimpleString("none".to_string())
}
pub fn decode(&self) -> String {
match self {
Protocol::SimpleString(s) => s.to_string(),
Protocol::BulkString(s) => s.to_string(),
Protocol::Null => "".to_string(),
Protocol::Array(s) => s.iter().map(|x| x.decode()).collect::<Vec<_>>().join(" "),
Protocol::Error(s) => s.to_string(),
}
}
pub fn encode(&self) -> String {
match self {
Protocol::SimpleString(s) => format!("+{}\r\n", s),
Protocol::BulkString(s) => format!("${}\r\n{}\r\n", s.len(), s),
Protocol::Array(ss) => {
format!("*{}\r\n", ss.len()) + &ss.iter().map(|x| x.encode()).collect::<String>()
}
Protocol::Null => "$-1\r\n".to_string(),
Protocol::Error(s) => format!("-{}\r\n", s), // proper RESP error
}
}
fn parse_simple_string_sfx(protocol: &str) -> Result<(Self, &str), DBError> {
match protocol.find("\r\n") {
Some(x) => Ok((Self::SimpleString(protocol[..x].to_string()), &protocol[x + 2..])),
_ => Err(DBError(format!(
"[new simple string] unsupported protocol: {:?}",
protocol
))),
}
}
fn parse_bulk_string_sfx(protocol: &str) -> Result<(Self, &str), DBError> {
if let Some(len_end) = protocol.find("\r\n") {
let size = Self::parse_usize(&protocol[..len_end])?;
let data_start = len_end + 2;
let data_end = data_start + size;
let s = Self::parse_string(&protocol[data_start..data_end])?;
if protocol.len() < data_end + 2 || &protocol[data_end..data_end+2] != "\r\n" {
Err(DBError(format!(
"[new bulk string] unmatched string length in prototocl {:?}",
protocol,
)))
} else {
Ok((Protocol::BulkString(s), &protocol[data_end + 2..]))
}
} else {
Err(DBError(format!(
"[new bulk string] unsupported protocol: {:?}",
protocol
)))
}
}
fn parse_array_sfx(s: &str) -> Result<(Self, &str), DBError> {
if let Some(len_end) = s.find("\r\n") {
let array_len = s[..len_end].parse::<usize>()?;
let mut remaining = &s[len_end + 2..];
let mut vec = vec![];
for _ in 0..array_len {
let (p, rem) = Protocol::from(remaining)?;
vec.push(p);
remaining = rem;
}
Ok((Protocol::Array(vec), remaining))
} else {
Err(DBError(format!(
"[new array] unsupported protocol: {:?}",
s
)))
}
}
fn parse_usize(protocol: &str) -> Result<usize, DBError> {
if protocol.is_empty() {
Err(DBError("Cannot parse usize from empty string".to_string()))
} else {
protocol
.parse::<usize>()
.map_err(|_| DBError(format!("Failed to parse usize from: {}", protocol)))
}
}
fn parse_string(protocol: &str) -> Result<String, DBError> {
if protocol.is_empty() {
// Allow empty strings, but handle appropriately
Ok("".to_string())
} else {
Ok(protocol.to_string())
}
}
}

136
crates/herodb/server.rs Normal file
View File

@@ -0,0 +1,136 @@
use core::str;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
use crate::cmd::Cmd;
use crate::error::DBError;
use crate::options;
use crate::protocol::Protocol;
use crate::storage::Storage;
#[derive(Clone)]
pub struct Server {
pub db_cache: std::sync::Arc<std::sync::RwLock<HashMap<u64, Arc<Storage>>>>,
pub option: options::DBOption,
pub client_name: Option<String>,
pub selected_db: u64, // Changed from usize to u64
pub queued_cmd: Option<Vec<(Cmd, Protocol)>>,
}
impl Server {
pub async fn new(option: options::DBOption) -> Self {
Server {
db_cache: Arc::new(std::sync::RwLock::new(HashMap::new())),
option,
client_name: None,
selected_db: 0,
queued_cmd: None,
}
}
pub fn current_storage(&self) -> Result<Arc<libdbstorage::Storage>, libdbstorage::DBError> {
let mut cache = self.db_cache.write().unwrap();
if let Some(storage) = cache.get(&self.selected_db) {
return Ok(storage.clone());
}
// Create new database file
let db_file_path = std::path::PathBuf::from(self.option.dir.clone())
.join(format!("{}.db", self.selected_db));
// Ensure the directory exists before creating the database file
if let Some(parent_dir) = db_file_path.parent() {
std::fs::create_dir_all(parent_dir).map_err(|e| {
DBError(format!("Failed to create directory {}: {}", parent_dir.display(), e))
})?;
}
println!("Creating new db file: {}", db_file_path.display());
let storage = Arc::new(Storage::new(
db_file_path,
self.should_encrypt_db(self.selected_db),
self.option.encryption_key.as_deref()
)?);
cache.insert(self.selected_db, storage.clone());
Ok(storage)
}
fn should_encrypt_db(&self, db_index: u64) -> bool {
// DB 0-9 are non-encrypted, DB 10+ are encrypted
self.option.encrypt && db_index >= 10
}
pub async fn handle(
&mut self,
mut stream: tokio::net::TcpStream,
) -> Result<(), DBError> {
let mut buf = [0; 512];
loop {
let len = match stream.read(&mut buf).await {
Ok(0) => {
println!("[handle] connection closed");
return Ok(());
}
Ok(len) => len,
Err(e) => {
println!("[handle] read error: {:?}", e);
return Err(e.into());
}
};
let mut s = str::from_utf8(&buf[..len])?;
while !s.is_empty() {
let (cmd, protocol, remaining) = match Cmd::from(s) {
Ok((cmd, protocol, remaining)) => (cmd, protocol, remaining),
Err(e) => {
println!("\x1b[31;1mprotocol error: {:?}\x1b[0m", e);
(Cmd::Unknow("protocol_error".to_string()), Protocol::err(&format!("protocol error: {}", e.0)), "")
}
};
s = remaining;
if self.option.debug {
println!("\x1b[34;1mgot command: {:?}, protocol: {:?}\x1b[0m", cmd, protocol);
} else {
println!("got command: {:?}, protocol: {:?}", cmd, protocol);
}
// Check if this is a QUIT command before processing
let is_quit = matches!(cmd, Cmd::Quit);
let res = match cmd.run(self).await {
Ok(p) => p,
Err(e) => {
if self.option.debug {
eprintln!("[run error] {:?}", e);
}
Protocol::err(&format!("ERR {}", e.0))
}
};
if self.option.debug {
println!("\x1b[34;1mqueued cmd {:?}\x1b[0m", self.queued_cmd);
println!("\x1b[32;1mgoing to send response {}\x1b[0m", res.encode());
} else {
print!("queued cmd {:?}", self.queued_cmd);
println!("going to send response {}", res.encode());
}
_ = stream.write(res.encode().as_bytes()).await?;
// If this was a QUIT command, close the connection
if is_quit {
println!("[handle] QUIT command received, closing connection");
return Ok(());
}
}
}
}
}

View File

@@ -0,0 +1 @@
fn main() {}

View File

@@ -0,0 +1,62 @@
use herodb::{server::Server, options::DBOption};
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::time::sleep;
// Helper function to send command and get response
async fn send_command(stream: &mut TcpStream, command: &str) -> String {
stream.write_all(command.as_bytes()).await.unwrap();
let mut buffer = [0; 1024];
let n = stream.read(&mut buffer).await.unwrap();
String::from_utf8_lossy(&buffer[..n]).to_string()
}
#[tokio::test]
async fn debug_hset_simple() {
// Clean up any existing test database
let test_dir = "/tmp/herodb_debug_hset";
let _ = std::fs::remove_dir_all(test_dir);
std::fs::create_dir_all(test_dir).unwrap();
let port = 16500;
let option = DBOption {
dir: test_dir.to_string(),
port,
debug: false,
encrypt: false,
encryption_key: None,
};
let mut server = Server::new(option).await;
// Start server in background
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(200)).await;
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port)).await.unwrap();
// Test simple HSET
println!("Testing HSET...");
let response = send_command(&mut stream, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n").await;
println!("HSET response: {}", response);
assert!(response.contains("1"), "Expected '1' but got: {}", response);
// Test HGET
println!("Testing HGET...");
let response = send_command(&mut stream, "*3\r\n$4\r\nHGET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await;
println!("HGET response: {}", response);
assert!(response.contains("value1"), "Expected 'value1' but got: {}", response);
}

View File

@@ -0,0 +1,56 @@
use herodb::{server::Server, options::DBOption};
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::time::sleep;
#[tokio::test]
async fn debug_hset_return_value() {
let test_dir = "/tmp/herodb_debug_hset_return";
// Clean up any existing test data
let _ = std::fs::remove_dir_all(&test_dir);
std::fs::create_dir_all(&test_dir).unwrap();
let option = DBOption {
dir: test_dir.to_string(),
port: 16390,
debug: false,
encrypt: false,
encryption_key: None,
};
let mut server = Server::new(option).await;
// Start server in background
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind("127.0.0.1:16390")
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(200)).await;
// Connect and test HSET
let mut stream = TcpStream::connect("127.0.0.1:16390").await.unwrap();
// Send HSET command
let cmd = "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n";
stream.write_all(cmd.as_bytes()).await.unwrap();
let mut buffer = [0; 1024];
let n = stream.read(&mut buffer).await.unwrap();
let response = String::from_utf8_lossy(&buffer[..n]);
println!("HSET response: {}", response);
println!("Response bytes: {:?}", &buffer[..n]);
// Check if response contains "1"
assert!(response.contains("1"), "Expected response to contain '1', got: {}", response);
}

View File

@@ -0,0 +1,35 @@
use herodb::protocol::Protocol;
use herodb::cmd::Cmd;
#[test]
fn test_protocol_parsing() {
// Test TYPE command parsing
let type_cmd = "*2\r\n$4\r\nTYPE\r\n$7\r\nnoexist\r\n";
println!("Parsing TYPE command: {}", type_cmd.replace("\r\n", "\\r\\n"));
match Protocol::from(type_cmd) {
Ok((protocol, _)) => {
println!("Protocol parsed successfully: {:?}", protocol);
match Cmd::from(type_cmd) {
Ok((cmd, _, _)) => println!("Command parsed successfully: {:?}", cmd),
Err(e) => println!("Command parsing failed: {:?}", e),
}
}
Err(e) => println!("Protocol parsing failed: {:?}", e),
}
// Test HEXISTS command parsing
let hexists_cmd = "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$7\r\nnoexist\r\n";
println!("\nParsing HEXISTS command: {}", hexists_cmd.replace("\r\n", "\\r\\n"));
match Protocol::from(hexists_cmd) {
Ok((protocol, _)) => {
println!("Protocol parsed successfully: {:?}", protocol);
match Cmd::from(hexists_cmd) {
Ok((cmd, _, _)) => println!("Command parsed successfully: {:?}", cmd),
Err(e) => println!("Command parsing failed: {:?}", e),
}
}
Err(e) => println!("Protocol parsing failed: {:?}", e),
}
}

View File

@@ -0,0 +1,317 @@
use redis::{Client, Commands, Connection, RedisResult};
use std::process::{Child, Command};
use std::time::Duration;
use tokio::time::sleep;
// Helper function to get Redis connection, retrying until successful
fn get_redis_connection(port: u16) -> Connection {
let connection_info = format!("redis://127.0.0.1:{}", port);
let client = Client::open(connection_info).unwrap();
let mut attempts = 0;
loop {
match client.get_connection() {
Ok(mut conn) => {
if redis::cmd("PING").query::<String>(&mut conn).is_ok() {
return conn;
}
}
Err(e) => {
if attempts >= 20 {
panic!(
"Failed to connect to Redis server after 20 attempts: {}",
e
);
}
}
}
attempts += 1;
std::thread::sleep(Duration::from_millis(100));
}
}
// A guard to ensure the server process is killed when it goes out of scope
struct ServerProcessGuard {
process: Child,
test_dir: String,
}
impl Drop for ServerProcessGuard {
fn drop(&mut self) {
println!("Killing server process (pid: {})...", self.process.id());
if let Err(e) = self.process.kill() {
eprintln!("Failed to kill server process: {}", e);
}
match self.process.wait() {
Ok(status) => println!("Server process exited with: {}", status),
Err(e) => eprintln!("Failed to wait on server process: {}", e),
}
// Clean up the specific test directory
println!("Cleaning up test directory: {}", self.test_dir);
if let Err(e) = std::fs::remove_dir_all(&self.test_dir) {
eprintln!("Failed to clean up test directory: {}", e);
}
}
}
// Helper to set up the server and return a connection
fn setup_server() -> (ServerProcessGuard, u16) {
use std::sync::atomic::{AtomicU16, Ordering};
static PORT_COUNTER: AtomicU16 = AtomicU16::new(16400);
let port = PORT_COUNTER.fetch_add(1, Ordering::SeqCst);
let test_dir = format!("/tmp/herodb_test_{}", port);
// Clean up previous test data
if std::path::Path::new(&test_dir).exists() {
let _ = std::fs::remove_dir_all(&test_dir);
}
std::fs::create_dir_all(&test_dir).unwrap();
// Start the server in a subprocess
let child = Command::new("cargo")
.args(&[
"run",
"--",
"--dir",
&test_dir,
"--port",
&port.to_string(),
"--debug",
])
.spawn()
.expect("Failed to start server process");
// Create a new guard that also owns the test directory path
let guard = ServerProcessGuard {
process: child,
test_dir,
};
// Give the server a moment to start
std::thread::sleep(Duration::from_millis(500));
(guard, port)
}
async fn cleanup_keys(conn: &mut Connection) {
let keys: Vec<String> = redis::cmd("KEYS").arg("*").query(conn).unwrap();
if !keys.is_empty() {
for key in keys {
let _: () = redis::cmd("DEL").arg(key).query(conn).unwrap();
}
}
}
#[tokio::test]
async fn all_tests() {
let (_server_guard, port) = setup_server();
let mut conn = get_redis_connection(port);
// Run all tests using the same connection
test_basic_ping(&mut conn).await;
test_string_operations(&mut conn).await;
test_incr_operations(&mut conn).await;
test_hash_operations(&mut conn).await;
test_expiration(&mut conn).await;
test_scan_operations(&mut conn).await;
test_scan_with_count(&mut conn).await;
test_hscan_operations(&mut conn).await;
test_transaction_operations(&mut conn).await;
test_discard_transaction(&mut conn).await;
test_type_command(&mut conn).await;
test_info_command(&mut conn).await;
}
async fn test_basic_ping(conn: &mut Connection) {
cleanup_keys(conn).await;
let result: String = redis::cmd("PING").query(conn).unwrap();
assert_eq!(result, "PONG");
cleanup_keys(conn).await;
}
async fn test_string_operations(conn: &mut Connection) {
cleanup_keys(conn).await;
let _: () = conn.set("key", "value").unwrap();
let result: String = conn.get("key").unwrap();
assert_eq!(result, "value");
let result: Option<String> = conn.get("noexist").unwrap();
assert_eq!(result, None);
let deleted: i32 = conn.del("key").unwrap();
assert_eq!(deleted, 1);
let result: Option<String> = conn.get("key").unwrap();
assert_eq!(result, None);
cleanup_keys(conn).await;
}
async fn test_incr_operations(conn: &mut Connection) {
cleanup_keys(conn).await;
let result: i32 = redis::cmd("INCR").arg("counter").query(conn).unwrap();
assert_eq!(result, 1);
let result: i32 = redis::cmd("INCR").arg("counter").query(conn).unwrap();
assert_eq!(result, 2);
let _: () = conn.set("string", "hello").unwrap();
let result: RedisResult<i32> = redis::cmd("INCR").arg("string").query(conn);
assert!(result.is_err());
cleanup_keys(conn).await;
}
async fn test_hash_operations(conn: &mut Connection) {
cleanup_keys(conn).await;
let result: i32 = conn.hset("hash", "field1", "value1").unwrap();
assert_eq!(result, 1);
let result: String = conn.hget("hash", "field1").unwrap();
assert_eq!(result, "value1");
let _: () = conn.hset("hash", "field2", "value2").unwrap();
let _: () = conn.hset("hash", "field3", "value3").unwrap();
let result: std::collections::HashMap<String, String> = conn.hgetall("hash").unwrap();
assert_eq!(result.len(), 3);
assert_eq!(result.get("field1").unwrap(), "value1");
assert_eq!(result.get("field2").unwrap(), "value2");
assert_eq!(result.get("field3").unwrap(), "value3");
let result: i32 = conn.hlen("hash").unwrap();
assert_eq!(result, 3);
let result: bool = conn.hexists("hash", "field1").unwrap();
assert_eq!(result, true);
let result: bool = conn.hexists("hash", "noexist").unwrap();
assert_eq!(result, false);
let result: i32 = conn.hdel("hash", "field1").unwrap();
assert_eq!(result, 1);
let mut result: Vec<String> = conn.hkeys("hash").unwrap();
result.sort();
assert_eq!(result, vec!["field2", "field3"]);
let mut result: Vec<String> = conn.hvals("hash").unwrap();
result.sort();
assert_eq!(result, vec!["value2", "value3"]);
cleanup_keys(conn).await;
}
async fn test_expiration(conn: &mut Connection) {
cleanup_keys(conn).await;
let _: () = conn.set_ex("expkey", "value", 1).unwrap();
let result: i32 = conn.ttl("expkey").unwrap();
assert!(result == 1 || result == 0);
let result: bool = conn.exists("expkey").unwrap();
assert_eq!(result, true);
sleep(Duration::from_millis(1100)).await;
let result: Option<String> = conn.get("expkey").unwrap();
assert_eq!(result, None);
let result: i32 = conn.ttl("expkey").unwrap();
assert_eq!(result, -2);
let result: bool = conn.exists("expkey").unwrap();
assert_eq!(result, false);
cleanup_keys(conn).await;
}
async fn test_scan_operations(conn: &mut Connection) {
cleanup_keys(conn).await;
for i in 0..5 {
let _: () = conn.set(format!("key{}", i), format!("value{}", i)).unwrap();
}
let result: (u64, Vec<String>) = redis::cmd("SCAN")
.arg(0)
.arg("MATCH")
.arg("key*")
.arg("COUNT")
.arg(10)
.query(conn)
.unwrap();
let (cursor, keys) = result;
assert_eq!(cursor, 0);
assert_eq!(keys.len(), 5);
cleanup_keys(conn).await;
}
async fn test_scan_with_count(conn: &mut Connection) {
cleanup_keys(conn).await;
for i in 0..15 {
let _: () = conn.set(format!("scan_key{}", i), i).unwrap();
}
let mut cursor = 0;
let mut all_keys = std::collections::HashSet::new();
loop {
let (next_cursor, keys): (u64, Vec<String>) = redis::cmd("SCAN")
.arg(cursor)
.arg("MATCH")
.arg("scan_key*")
.arg("COUNT")
.arg(5)
.query(conn)
.unwrap();
for key in keys {
all_keys.insert(key);
}
cursor = next_cursor;
if cursor == 0 {
break;
}
}
assert_eq!(all_keys.len(), 15);
cleanup_keys(conn).await;
}
async fn test_hscan_operations(conn: &mut Connection) {
cleanup_keys(conn).await;
for i in 0..3 {
let _: () = conn.hset("testhash", format!("field{}", i), format!("value{}", i)).unwrap();
}
let result: (u64, Vec<String>) = redis::cmd("HSCAN")
.arg("testhash")
.arg(0)
.arg("MATCH")
.arg("*")
.arg("COUNT")
.arg(10)
.query(conn)
.unwrap();
let (cursor, fields) = result;
assert_eq!(cursor, 0);
assert_eq!(fields.len(), 6);
cleanup_keys(conn).await;
}
async fn test_transaction_operations(conn: &mut Connection) {
cleanup_keys(conn).await;
let _: () = redis::cmd("MULTI").query(conn).unwrap();
let _: () = redis::cmd("SET").arg("key1").arg("value1").query(conn).unwrap();
let _: () = redis::cmd("SET").arg("key2").arg("value2").query(conn).unwrap();
let _: Vec<String> = redis::cmd("EXEC").query(conn).unwrap();
let result: String = conn.get("key1").unwrap();
assert_eq!(result, "value1");
let result: String = conn.get("key2").unwrap();
assert_eq!(result, "value2");
cleanup_keys(conn).await;
}
async fn test_discard_transaction(conn: &mut Connection) {
cleanup_keys(conn).await;
let _: () = redis::cmd("MULTI").query(conn).unwrap();
let _: () = redis::cmd("SET").arg("discard").arg("value").query(conn).unwrap();
let _: () = redis::cmd("DISCARD").query(conn).unwrap();
let result: Option<String> = conn.get("discard").unwrap();
assert_eq!(result, None);
cleanup_keys(conn).await;
}
async fn test_type_command(conn: &mut Connection) {
cleanup_keys(conn).await;
let _: () = conn.set("string", "value").unwrap();
let result: String = redis::cmd("TYPE").arg("string").query(conn).unwrap();
assert_eq!(result, "string");
let _: () = conn.hset("hash", "field", "value").unwrap();
let result: String = redis::cmd("TYPE").arg("hash").query(conn).unwrap();
assert_eq!(result, "hash");
let result: String = redis::cmd("TYPE").arg("noexist").query(conn).unwrap();
assert_eq!(result, "none");
cleanup_keys(conn).await;
}
async fn test_info_command(conn: &mut Connection) {
cleanup_keys(conn).await;
let result: String = redis::cmd("INFO").query(conn).unwrap();
assert!(result.contains("redis_version"));
let result: String = redis::cmd("INFO").arg("replication").query(conn).unwrap();
assert!(result.contains("role:master"));
cleanup_keys(conn).await;
}

View File

@@ -0,0 +1,608 @@
use herodb::{server::Server, options::DBOption};
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::time::sleep;
// Helper function to start a test server
async fn start_test_server(test_name: &str) -> (Server, u16) {
use std::sync::atomic::{AtomicU16, Ordering};
static PORT_COUNTER: AtomicU16 = AtomicU16::new(16379);
let port = PORT_COUNTER.fetch_add(1, Ordering::SeqCst);
let test_dir = format!("/tmp/herodb_test_{}", test_name);
// Clean up and create test directory
let _ = std::fs::remove_dir_all(&test_dir);
std::fs::create_dir_all(&test_dir).unwrap();
let option = DBOption {
dir: test_dir,
port,
debug: true,
encrypt: false,
encryption_key: None,
};
let server = Server::new(option).await;
(server, port)
}
// Helper function to connect to the test server
async fn connect_to_server(port: u16) -> TcpStream {
let mut attempts = 0;
loop {
match TcpStream::connect(format!("127.0.0.1:{}", port)).await {
Ok(stream) => return stream,
Err(_) if attempts < 10 => {
attempts += 1;
sleep(Duration::from_millis(100)).await;
}
Err(e) => panic!("Failed to connect to test server: {}", e),
}
}
}
// Helper function to send command and get response
async fn send_command(stream: &mut TcpStream, command: &str) -> String {
stream.write_all(command.as_bytes()).await.unwrap();
let mut buffer = [0; 1024];
let n = stream.read(&mut buffer).await.unwrap();
String::from_utf8_lossy(&buffer[..n]).to_string()
}
#[tokio::test]
async fn test_basic_ping() {
let (mut server, port) = start_test_server("ping").await;
// Start server in background
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
let response = send_command(&mut stream, "*1\r\n$4\r\nPING\r\n").await;
assert!(response.contains("PONG"));
}
#[tokio::test]
async fn test_string_operations() {
let (mut server, port) = start_test_server("string").await;
// Start server in background
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
// Test SET
let response = send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n").await;
assert!(response.contains("OK"));
// Test GET
let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$3\r\nkey\r\n").await;
assert!(response.contains("value"));
// Test GET non-existent key
let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$7\r\nnoexist\r\n").await;
assert!(response.contains("$-1")); // NULL response
// Test DEL
let response = send_command(&mut stream, "*2\r\n$3\r\nDEL\r\n$3\r\nkey\r\n").await;
assert!(response.contains("1"));
// Test GET after DEL
let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$3\r\nkey\r\n").await;
assert!(response.contains("$-1")); // NULL response
}
#[tokio::test]
async fn test_incr_operations() {
let (mut server, port) = start_test_server("incr").await;
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
// Test INCR on non-existent key
let response = send_command(&mut stream, "*2\r\n$4\r\nINCR\r\n$7\r\ncounter\r\n").await;
assert!(response.contains("1"));
// Test INCR on existing key
let response = send_command(&mut stream, "*2\r\n$4\r\nINCR\r\n$7\r\ncounter\r\n").await;
assert!(response.contains("2"));
// Test INCR on string value (should fail)
send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$6\r\nstring\r\n$5\r\nhello\r\n").await;
let response = send_command(&mut stream, "*2\r\n$4\r\nINCR\r\n$6\r\nstring\r\n").await;
assert!(response.contains("ERR"));
}
#[tokio::test]
async fn test_hash_operations() {
let (mut server, port) = start_test_server("hash").await;
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
// Test HSET
let response = send_command(&mut stream, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n").await;
assert!(response.contains("1")); // 1 new field
// Test HGET
let response = send_command(&mut stream, "*3\r\n$4\r\nHGET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await;
assert!(response.contains("value1"));
// Test HSET multiple fields
let response = send_command(&mut stream, "*6\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield2\r\n$6\r\nvalue2\r\n$6\r\nfield3\r\n$6\r\nvalue3\r\n").await;
assert!(response.contains("2")); // 2 new fields
// Test HGETALL
let response = send_command(&mut stream, "*2\r\n$7\r\nHGETALL\r\n$4\r\nhash\r\n").await;
assert!(response.contains("field1"));
assert!(response.contains("value1"));
assert!(response.contains("field2"));
assert!(response.contains("value2"));
// Test HLEN
let response = send_command(&mut stream, "*2\r\n$4\r\nHLEN\r\n$4\r\nhash\r\n").await;
assert!(response.contains("3"));
// Test HEXISTS
let response = send_command(&mut stream, "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await;
assert!(response.contains("1"));
let response = send_command(&mut stream, "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$7\r\nnoexist\r\n").await;
assert!(response.contains("0"));
// Test HDEL
let response = send_command(&mut stream, "*3\r\n$4\r\nHDEL\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await;
assert!(response.contains("1"));
// Test HKEYS
let response = send_command(&mut stream, "*2\r\n$5\r\nHKEYS\r\n$4\r\nhash\r\n").await;
assert!(response.contains("field2"));
assert!(response.contains("field3"));
assert!(!response.contains("field1")); // Should be deleted
// Test HVALS
let response = send_command(&mut stream, "*2\r\n$5\r\nHVALS\r\n$4\r\nhash\r\n").await;
assert!(response.contains("value2"));
assert!(response.contains("value3"));
}
#[tokio::test]
async fn test_expiration() {
let (mut server, port) = start_test_server("expiration").await;
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
// Test SETEX (expire in 1 second)
let response = send_command(&mut stream, "*5\r\n$3\r\nSET\r\n$6\r\nexpkey\r\n$5\r\nvalue\r\n$2\r\nEX\r\n$1\r\n1\r\n").await;
assert!(response.contains("OK"));
// Test TTL
let response = send_command(&mut stream, "*2\r\n$3\r\nTTL\r\n$6\r\nexpkey\r\n").await;
assert!(response.contains("1") || response.contains("0")); // Should be 1 or 0 seconds
// Test EXISTS
let response = send_command(&mut stream, "*2\r\n$6\r\nEXISTS\r\n$6\r\nexpkey\r\n").await;
assert!(response.contains("1"));
// Wait for expiration
sleep(Duration::from_millis(1100)).await;
// Test GET after expiration
let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$6\r\nexpkey\r\n").await;
assert!(response.contains("$-1")); // Should be NULL
// Test TTL after expiration
let response = send_command(&mut stream, "*2\r\n$3\r\nTTL\r\n$6\r\nexpkey\r\n").await;
assert!(response.contains("-2")); // Key doesn't exist
// Test EXISTS after expiration
let response = send_command(&mut stream, "*2\r\n$6\r\nEXISTS\r\n$6\r\nexpkey\r\n").await;
assert!(response.contains("0"));
}
#[tokio::test]
async fn test_scan_operations() {
let (mut server, port) = start_test_server("scan").await;
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
// Set up test data
for i in 0..5 {
let cmd = format!("*3\r\n$3\r\nSET\r\n$4\r\nkey{}\r\n$6\r\nvalue{}\r\n", i, i);
send_command(&mut stream, &cmd).await;
}
// Test SCAN
let response = send_command(&mut stream, "*6\r\n$4\r\nSCAN\r\n$1\r\n0\r\n$5\r\nMATCH\r\n$1\r\n*\r\n$5\r\nCOUNT\r\n$2\r\n10\r\n").await;
assert!(response.contains("key"));
// Test KEYS
let response = send_command(&mut stream, "*2\r\n$4\r\nKEYS\r\n$1\r\n*\r\n").await;
assert!(response.contains("key0"));
assert!(response.contains("key1"));
}
#[tokio::test]
async fn test_hscan_operations() {
let (mut server, port) = start_test_server("hscan").await;
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
// Set up hash data
for i in 0..3 {
let cmd = format!("*4\r\n$4\r\nHSET\r\n$8\r\ntesthash\r\n$6\r\nfield{}\r\n$6\r\nvalue{}\r\n", i, i);
send_command(&mut stream, &cmd).await;
}
// Test HSCAN
let response = send_command(&mut stream, "*7\r\n$5\r\nHSCAN\r\n$8\r\ntesthash\r\n$1\r\n0\r\n$5\r\nMATCH\r\n$1\r\n*\r\n$5\r\nCOUNT\r\n$2\r\n10\r\n").await;
assert!(response.contains("field"));
assert!(response.contains("value"));
}
#[tokio::test]
async fn test_transaction_operations() {
let (mut server, port) = start_test_server("transaction").await;
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
// Test MULTI
let response = send_command(&mut stream, "*1\r\n$5\r\nMULTI\r\n").await;
assert!(response.contains("OK"));
// Test queued commands
let response = send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$4\r\nkey1\r\n$6\r\nvalue1\r\n").await;
assert!(response.contains("QUEUED"));
let response = send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$4\r\nkey2\r\n$6\r\nvalue2\r\n").await;
assert!(response.contains("QUEUED"));
// Test EXEC
let response = send_command(&mut stream, "*1\r\n$4\r\nEXEC\r\n").await;
assert!(response.contains("OK")); // Should contain results of executed commands
// Verify commands were executed
let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$4\r\nkey1\r\n").await;
assert!(response.contains("value1"));
let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$4\r\nkey2\r\n").await;
assert!(response.contains("value2"));
}
#[tokio::test]
async fn test_discard_transaction() {
let (mut server, port) = start_test_server("discard").await;
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
// Test MULTI
let response = send_command(&mut stream, "*1\r\n$5\r\nMULTI\r\n").await;
assert!(response.contains("OK"));
// Test queued command
let response = send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$7\r\ndiscard\r\n$5\r\nvalue\r\n").await;
assert!(response.contains("QUEUED"));
// Test DISCARD
let response = send_command(&mut stream, "*1\r\n$7\r\nDISCARD\r\n").await;
assert!(response.contains("OK"));
// Verify command was not executed
let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$7\r\ndiscard\r\n").await;
assert!(response.contains("$-1")); // Should be NULL
}
#[tokio::test]
async fn test_type_command() {
let (mut server, port) = start_test_server("type").await;
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
// Test string type
send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$6\r\nstring\r\n$5\r\nvalue\r\n").await;
let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$6\r\nstring\r\n").await;
assert!(response.contains("string"));
// Test hash type
send_command(&mut stream, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$5\r\nfield\r\n$5\r\nvalue\r\n").await;
let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$4\r\nhash\r\n").await;
assert!(response.contains("hash"));
// Test non-existent key
let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$7\r\nnoexist\r\n").await;
assert!(response.contains("none"));
}
#[tokio::test]
async fn test_config_commands() {
let (mut server, port) = start_test_server("config").await;
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
// Test CONFIG GET databases
let response = send_command(&mut stream, "*3\r\n$6\r\nCONFIG\r\n$3\r\nGET\r\n$9\r\ndatabases\r\n").await;
assert!(response.contains("databases"));
assert!(response.contains("16"));
// Test CONFIG GET dir
let response = send_command(&mut stream, "*3\r\n$6\r\nCONFIG\r\n$3\r\nGET\r\n$3\r\ndir\r\n").await;
assert!(response.contains("dir"));
assert!(response.contains("/tmp/herodb_test_config"));
}
#[tokio::test]
async fn test_info_command() {
let (mut server, port) = start_test_server("info").await;
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
// Test INFO
let response = send_command(&mut stream, "*1\r\n$4\r\nINFO\r\n").await;
assert!(response.contains("redis_version"));
// Test INFO replication
let response = send_command(&mut stream, "*2\r\n$4\r\nINFO\r\n$11\r\nreplication\r\n").await;
assert!(response.contains("role:master"));
}
#[tokio::test]
async fn test_error_handling() {
let (mut server, port) = start_test_server("error").await;
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
// Test WRONGTYPE error - try to use hash command on string
send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$6\r\nstring\r\n$5\r\nvalue\r\n").await;
let response = send_command(&mut stream, "*3\r\n$4\r\nHGET\r\n$6\r\nstring\r\n$5\r\nfield\r\n").await;
assert!(response.contains("WRONGTYPE"));
// Test unknown command
let response = send_command(&mut stream, "*1\r\n$7\r\nUNKNOWN\r\n").await;
assert!(response.contains("unknown cmd") || response.contains("ERR"));
// Test EXEC without MULTI
let response = send_command(&mut stream, "*1\r\n$4\r\nEXEC\r\n").await;
assert!(response.contains("ERR"));
// Test DISCARD without MULTI
let response = send_command(&mut stream, "*1\r\n$7\r\nDISCARD\r\n").await;
assert!(response.contains("ERR"));
}
#[tokio::test]
async fn test_list_operations() {
let (mut server, port) = start_test_server("list").await;
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
// Test LPUSH
let response = send_command(&mut stream, "*4\r\n$5\r\nLPUSH\r\n$4\r\nlist\r\n$1\r\na\r\n$1\r\nb\r\n").await;
assert!(response.contains("2")); // 2 elements
// Test RPUSH
let response = send_command(&mut stream, "*4\r\n$5\r\nRPUSH\r\n$4\r\nlist\r\n$1\r\nc\r\n$1\r\nd\r\n").await;
assert!(response.contains("4")); // 4 elements
// Test LLEN
let response = send_command(&mut stream, "*2\r\n$4\r\nLLEN\r\n$4\r\nlist\r\n").await;
assert!(response.contains("4"));
// Test LRANGE
let response = send_command(&mut stream, "*4\r\n$6\r\nLRANGE\r\n$4\r\nlist\r\n$1\r\n0\r\n$2\r\n-1\r\n").await;
assert_eq!(response, "*4\r\n$1\r\nb\r\n$1\r\na\r\n$1\r\nc\r\n$1\r\nd\r\n");
// Test LINDEX
let response = send_command(&mut stream, "*3\r\n$6\r\nLINDEX\r\n$4\r\nlist\r\n$1\r\n0\r\n").await;
assert_eq!(response, "$1\r\nb\r\n");
// Test LPOP
let response = send_command(&mut stream, "*2\r\n$4\r\nLPOP\r\n$4\r\nlist\r\n").await;
assert_eq!(response, "$1\r\nb\r\n");
// Test RPOP
let response = send_command(&mut stream, "*2\r\n$4\r\nRPOP\r\n$4\r\nlist\r\n").await;
assert_eq!(response, "$1\r\nd\r\n");
// Test LREM
send_command(&mut stream, "*3\r\n$5\r\nLPUSH\r\n$4\r\nlist\r\n$1\r\na\r\n").await; // list is now a, c, a
let response = send_command(&mut stream, "*4\r\n$4\r\nLREM\r\n$4\r\nlist\r\n$1\r\n1\r\n$1\r\na\r\n").await;
assert!(response.contains("1"));
// Test LTRIM
let response = send_command(&mut stream, "*4\r\n$5\r\nLTRIM\r\n$4\r\nlist\r\n$1\r\n0\r\n$1\r\n0\r\n").await;
assert!(response.contains("OK"));
let response = send_command(&mut stream, "*2\r\n$4\r\nLLEN\r\n$4\r\nlist\r\n").await;
assert!(response.contains("1"));
}

View File

@@ -0,0 +1,212 @@
use herodb::{server::Server, options::DBOption};
use std::time::Duration;
use tokio::time::sleep;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
// Helper function to start a test server with clean data directory
async fn start_test_server(test_name: &str) -> (Server, u16) {
use std::sync::atomic::{AtomicU16, Ordering};
static PORT_COUNTER: AtomicU16 = AtomicU16::new(17000);
// Get a unique port for this test
let port = PORT_COUNTER.fetch_add(1, Ordering::SeqCst);
let test_dir = format!("/tmp/herodb_test_{}", test_name);
// Clean up any existing test data
let _ = std::fs::remove_dir_all(&test_dir);
std::fs::create_dir_all(&test_dir).unwrap();
let option = DBOption {
dir: test_dir,
port,
debug: true,
encrypt: false,
encryption_key: None,
};
let server = Server::new(option).await;
(server, port)
}
// Helper function to send Redis command and get response
async fn send_redis_command(port: u16, command: &str) -> String {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port)).await.unwrap();
stream.write_all(command.as_bytes()).await.unwrap();
let mut buffer = [0; 1024];
let n = stream.read(&mut buffer).await.unwrap();
String::from_utf8_lossy(&buffer[..n]).to_string()
}
#[tokio::test]
async fn test_basic_redis_functionality() {
let (mut server, port) = start_test_server("basic").await;
// Start server in background with timeout
let server_handle = tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
// Accept only a few connections for testing
for _ in 0..10 {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
// Test PING
let response = send_redis_command(port, "*1\r\n$4\r\nPING\r\n").await;
assert!(response.contains("PONG"));
// Test SET
let response = send_redis_command(port, "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n").await;
assert!(response.contains("OK"));
// Test GET
let response = send_redis_command(port, "*2\r\n$3\r\nGET\r\n$3\r\nkey\r\n").await;
assert!(response.contains("value"));
// Test HSET
let response = send_redis_command(port, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$5\r\nfield\r\n$5\r\nvalue\r\n").await;
assert!(response.contains("1"));
// Test HGET
let response = send_redis_command(port, "*3\r\n$4\r\nHGET\r\n$4\r\nhash\r\n$5\r\nfield\r\n").await;
assert!(response.contains("value"));
// Test EXISTS
let response = send_redis_command(port, "*2\r\n$6\r\nEXISTS\r\n$3\r\nkey\r\n").await;
assert!(response.contains("1"));
// Test TTL
let response = send_redis_command(port, "*2\r\n$3\r\nTTL\r\n$3\r\nkey\r\n").await;
assert!(response.contains("-1")); // No expiration
// Test TYPE
let response = send_redis_command(port, "*2\r\n$4\r\nTYPE\r\n$3\r\nkey\r\n").await;
assert!(response.contains("string"));
// Test QUIT to close connection gracefully
let response = send_redis_command(port, "*1\r\n$4\r\nQUIT\r\n").await;
assert!(response.contains("OK"));
// Stop the server
server_handle.abort();
println!("✅ All basic Redis functionality tests passed!");
}
#[tokio::test]
async fn test_hash_operations() {
let (mut server, port) = start_test_server("hash_ops").await;
// Start server in background with timeout
let server_handle = tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
// Accept only a few connections for testing
for _ in 0..5 {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
// Test HSET multiple fields
let response = send_redis_command(port, "*6\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n$6\r\nfield2\r\n$6\r\nvalue2\r\n").await;
assert!(response.contains("2")); // 2 new fields
// Test HGETALL
let response = send_redis_command(port, "*2\r\n$7\r\nHGETALL\r\n$4\r\nhash\r\n").await;
assert!(response.contains("field1"));
assert!(response.contains("value1"));
assert!(response.contains("field2"));
assert!(response.contains("value2"));
// Test HEXISTS
let response = send_redis_command(port, "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await;
assert!(response.contains("1"));
// Test HLEN
let response = send_redis_command(port, "*2\r\n$4\r\nHLEN\r\n$4\r\nhash\r\n").await;
assert!(response.contains("2"));
// Test HSCAN
let response = send_redis_command(port, "*7\r\n$5\r\nHSCAN\r\n$4\r\nhash\r\n$1\r\n0\r\n$5\r\nMATCH\r\n$1\r\n*\r\n$5\r\nCOUNT\r\n$2\r\n10\r\n").await;
assert!(response.contains("field1"));
assert!(response.contains("value1"));
assert!(response.contains("field2"));
assert!(response.contains("value2"));
// Stop the server
server_handle.abort();
println!("✅ All hash operations tests passed!");
}
#[tokio::test]
async fn test_transaction_operations() {
let (mut server, port) = start_test_server("transactions").await;
// Start server in background with timeout
let server_handle = tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
// Accept only a few connections for testing
for _ in 0..5 {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
// Use a single connection for the transaction
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port)).await.unwrap();
// Test MULTI
stream.write_all("*1\r\n$5\r\nMULTI\r\n".as_bytes()).await.unwrap();
let mut buffer = [0; 1024];
let n = stream.read(&mut buffer).await.unwrap();
let response = String::from_utf8_lossy(&buffer[..n]);
assert!(response.contains("OK"));
// Test queued commands
stream.write_all("*3\r\n$3\r\nSET\r\n$4\r\nkey1\r\n$6\r\nvalue1\r\n".as_bytes()).await.unwrap();
let n = stream.read(&mut buffer).await.unwrap();
let response = String::from_utf8_lossy(&buffer[..n]);
assert!(response.contains("QUEUED"));
stream.write_all("*3\r\n$3\r\nSET\r\n$4\r\nkey2\r\n$6\r\nvalue2\r\n".as_bytes()).await.unwrap();
let n = stream.read(&mut buffer).await.unwrap();
let response = String::from_utf8_lossy(&buffer[..n]);
assert!(response.contains("QUEUED"));
// Test EXEC
stream.write_all("*1\r\n$4\r\nEXEC\r\n".as_bytes()).await.unwrap();
let n = stream.read(&mut buffer).await.unwrap();
let response = String::from_utf8_lossy(&buffer[..n]);
assert!(response.contains("OK")); // Should contain array of OK responses
// Verify commands were executed
let response = send_redis_command(port, "*2\r\n$3\r\nGET\r\n$4\r\nkey1\r\n").await;
assert!(response.contains("value1"));
// Stop the server
server_handle.abort();
println!("✅ All transaction operations tests passed!");
}

View File

@@ -0,0 +1,183 @@
use herodb::{server::Server, options::DBOption};
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::time::sleep;
// Helper function to start a test server with clean data directory
async fn start_test_server(test_name: &str) -> (Server, u16) {
use std::sync::atomic::{AtomicU16, Ordering};
static PORT_COUNTER: AtomicU16 = AtomicU16::new(16500);
let port = PORT_COUNTER.fetch_add(1, Ordering::SeqCst);
let test_dir = format!("/tmp/herodb_simple_test_{}", test_name);
// Clean up any existing test data
let _ = std::fs::remove_dir_all(&test_dir);
std::fs::create_dir_all(&test_dir).unwrap();
let option = DBOption {
dir: test_dir,
port,
debug: false,
encrypt: false,
encryption_key: None,
};
let server = Server::new(option).await;
(server, port)
}
// Helper function to send command and get response
async fn send_command(stream: &mut TcpStream, command: &str) -> String {
stream.write_all(command.as_bytes()).await.unwrap();
let mut buffer = [0; 1024];
let n = stream.read(&mut buffer).await.unwrap();
String::from_utf8_lossy(&buffer[..n]).to_string()
}
// Helper function to connect to the test server
async fn connect_to_server(port: u16) -> TcpStream {
let mut attempts = 0;
loop {
match TcpStream::connect(format!("127.0.0.1:{}", port)).await {
Ok(stream) => return stream,
Err(_) if attempts < 10 => {
attempts += 1;
sleep(Duration::from_millis(100)).await;
}
Err(e) => panic!("Failed to connect to test server: {}", e),
}
}
}
#[tokio::test]
async fn test_basic_ping_simple() {
let (mut server, port) = start_test_server("ping").await;
// Start server in background
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(200)).await;
let mut stream = connect_to_server(port).await;
let response = send_command(&mut stream, "*1\r\n$4\r\nPING\r\n").await;
assert!(response.contains("PONG"));
}
#[tokio::test]
async fn test_hset_clean_db() {
let (mut server, port) = start_test_server("hset_clean").await;
// Start server in background
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(200)).await;
let mut stream = connect_to_server(port).await;
// Test HSET - should return 1 for new field
let response = send_command(&mut stream, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n").await;
println!("HSET response: {}", response);
assert!(response.contains("1"), "Expected HSET to return 1, got: {}", response);
// Test HGET
let response = send_command(&mut stream, "*3\r\n$4\r\nHGET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await;
println!("HGET response: {}", response);
assert!(response.contains("value1"));
}
#[tokio::test]
async fn test_type_command_simple() {
let (mut server, port) = start_test_server("type").await;
// Start server in background
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(200)).await;
let mut stream = connect_to_server(port).await;
// Test string type
send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$6\r\nstring\r\n$5\r\nvalue\r\n").await;
let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$6\r\nstring\r\n").await;
println!("TYPE string response: {}", response);
assert!(response.contains("string"));
// Test hash type
send_command(&mut stream, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$5\r\nfield\r\n$5\r\nvalue\r\n").await;
let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$4\r\nhash\r\n").await;
println!("TYPE hash response: {}", response);
assert!(response.contains("hash"));
// Test non-existent key
let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$7\r\nnoexist\r\n").await;
println!("TYPE noexist response: {}", response);
assert!(response.contains("none"), "Expected 'none' for non-existent key, got: {}", response);
}
#[tokio::test]
async fn test_hexists_simple() {
let (mut server, port) = start_test_server("hexists").await;
// Start server in background
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(200)).await;
let mut stream = connect_to_server(port).await;
// Set up hash
send_command(&mut stream, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n").await;
// Test HEXISTS for existing field
let response = send_command(&mut stream, "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await;
println!("HEXISTS existing field response: {}", response);
assert!(response.contains("1"));
// Test HEXISTS for non-existent field
let response = send_command(&mut stream, "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$7\r\nnoexist\r\n").await;
println!("HEXISTS non-existent field response: {}", response);
assert!(response.contains("0"), "Expected HEXISTS to return 0 for non-existent field, got: {}", response);
}

View File

@@ -0,0 +1,10 @@
[package]
name = "libcrypto"
version = "0.1.0"
edition = "2021"
[dependencies]
chacha20poly1305 = { workspace = true }
rand = { workspace = true }
sha2 = { workspace = true }
thiserror = { workspace = true }

View File

@@ -0,0 +1,72 @@
// In crates/libcrypto/src/lib.rs
use chacha20poly1305::{
aead::{Aead, KeyInit, OsRng},
XChaCha20Poly1305, XNonce,
};
use rand::RngCore;
use sha2::{Digest, Sha256};
use thiserror::Error;
const VERSION: u8 = 1;
const NONCE_LEN: usize = 24;
const TAG_LEN: usize = 16;
#[derive(Error, Debug)]
pub enum CryptoError {
#[error("invalid format: data too short")]
Format,
#[error("unknown version: {0}")]
Version(u8),
#[error("decryption failed: wrong key or corrupted data")]
Decrypt,
}
/// Super-simple factory: new(secret) + encrypt(bytes) + decrypt(bytes)
pub struct CryptoFactory {
key: chacha20poly1305::Key,
}
impl CryptoFactory {
/// Accepts any secret bytes; turns them into a 32-byte key (SHA-256).
pub fn new<S: AsRef<[u8]>>(secret: S) -> Self {
let mut h = Sha256::new();
h.update(b"xchacha20poly1305-factory:v1"); // domain separation
h.update(secret.as_ref());
let digest = h.finalize(); // 32 bytes
let key = chacha20poly1305::Key::from_slice(&digest).to_owned();
Self { key }
}
/// Output layout: [version:1][nonce:24][ciphertext||tag]
pub fn encrypt(&self, plaintext: &[u8]) -> Vec<u8> {
let cipher = XChaCha20Poly1305::new(&self.key);
let mut nonce_bytes = [0u8; NONCE_LEN];
OsRng.fill_bytes(&mut nonce_bytes);
let nonce = XNonce::from_slice(&nonce_bytes);
let mut out = Vec::with_capacity(1 + NONCE_LEN + plaintext.len() + TAG_LEN);
out.push(VERSION);
out.extend_from_slice(&nonce_bytes);
let ct = cipher.encrypt(nonce, plaintext).expect("encrypt");
out.extend_from_slice(&ct);
out
}
pub fn decrypt(&self, blob: &[u8]) -> Result<Vec<u8>, CryptoError> {
if blob.len() < 1 + NONCE_LEN + TAG_LEN {
return Err(CryptoError::Format);
}
let ver = blob[0];
if ver != VERSION {
return Err(CryptoError::Version(ver));
}
let nonce = XNonce::from_slice(&blob[1..1 + NONCE_LEN]);
let ct = &blob[1 + NONCE_LEN..];
let cipher = XChaCha20Poly1305::new(&self.key);
cipher.decrypt(nonce, ct).map_err(|_| CryptoError::Decrypt)
}
}

View File

@@ -0,0 +1,12 @@
[package]
name = "libcryptoa"
version = "0.1.0"
edition = "2021"
[dependencies]
age = { workspace = true }
secrecy = { workspace = true }
ed25519-dalek = { workspace = true }
base64 = { workspace = true }
rand = { workspace = true }
thiserror = { workspace = true }

View File

@@ -0,0 +1,100 @@
// In crates/libcryptoa/src/lib.rs
use std::str::FromStr;
use age::{Decryptor, Encryptor, x25519};
use base64::{engine::general_purpose::STANDARD as B64, Engine as _};
use ed25519_dalek::{Signature, Signer, SigningKey, Verifier, VerifyingKey};
use secrecy::ExposeSecret;
use thiserror::Error;
#[derive(Error, Debug)]
pub enum AsymmetricCryptoError {
#[error("key parsing failed")]
ParseKey,
#[error("age crypto error: {0}")]
Age(String),
#[error("invalid utf-8 in plaintext")]
Utf8,
#[error("invalid signature length")]
SignatureLen,
#[error("signature verification failed")]
Verify,
#[error("base64 decoding failed: {0}")]
Base64(#[from] base64::DecodeError),
#[error("io error: {0}")]
Io(#[from] std::io::Error),
}
fn parse_recipient(s: &str) -> Result<x25519::Recipient, AsymmetricCryptoError> {
x25519::Recipient::from_str(s).map_err(|_| AsymmetricCryptoError::ParseKey)
}
fn parse_identity(s: &str) -> Result<x25519::Identity, AsymmetricCryptoError> {
x25519::Identity::from_str(s).map_err(|_| AsymmetricCryptoError::ParseKey)
}
fn parse_ed25519_signing_key(s: &str) -> Result<SigningKey, AsymmetricCryptoError> {
let bytes = B64.decode(s)?;
let key_bytes: [u8; 32] = bytes.try_into().map_err(|_| AsymmetricCryptoError::ParseKey)?;
Ok(SigningKey::from_bytes(&key_bytes))
}
fn parse_ed25519_verifying_key(s: &str) -> Result<VerifyingKey, AsymmetricCryptoError> {
let bytes = B64.decode(s)?;
let key_bytes: [u8; 32] = bytes.try_into().map_err(|_| AsymmetricCryptoError::ParseKey)?;
VerifyingKey::from_bytes(&key_bytes).map_err(|_| AsymmetricCryptoError::ParseKey)
}
pub fn gen_enc_keypair() -> (String, String) {
let id = x25519::Identity::generate();
let pk = id.to_public();
(pk.to_string(), id.to_string().expose_secret().to_string())
}
pub fn gen_sign_keypair() -> (String, String) {
let signing_key = SigningKey::generate(&mut rand::rngs::OsRng);
let verifying_key = signing_key.verifying_key();
(B64.encode(verifying_key.to_bytes()), B64.encode(signing_key.to_bytes()))
}
pub fn encrypt_b64(recipient_str: &str, msg: &str) -> Result<String, AsymmetricCryptoError> {
let recipient = parse_recipient(recipient_str)?;
let encryptor = Encryptor::with_recipients(vec![Box::new(recipient)])
.ok_or_else(|| AsymmetricCryptoError::Age("Failed to create encryptor".into()))?;
let mut encrypted = vec![];
let mut writer = encryptor.wrap_output(&mut encrypted)?;
std::io::Write::write_all(&mut writer, msg.as_bytes())?;
writer.finish()?;
Ok(B64.encode(encrypted))
}
pub fn decrypt_b64(identity_str: &str, ct_b64: &str) -> Result<String, AsymmetricCryptoError> {
let identity = parse_identity(identity_str)?;
let ct = B64.decode(ct_b64)?;
let decryptor = Decryptor::new(&ct[..]).map_err(|e| AsymmetricCryptoError::Age(e.to_string()))?;
let mut decrypted = vec![];
if let Decryptor::Recipients(d) = decryptor {
let mut reader = d.decrypt(std::iter::once(&identity as &dyn age::Identity))
.map_err(|e| AsymmetricCryptoError::Age(e.to_string()))?;
std::io::Read::read_to_end(&mut reader, &mut decrypted)?;
String::from_utf8(decrypted).map_err(|_| AsymmetricCryptoError::Utf8)
} else {
Err(AsymmetricCryptoError::Age("Passphrase decryption not supported".into()))
}
}
pub fn sign_b64(signing_secret_str: &str, msg: &str) -> Result<String, AsymmetricCryptoError> {
let signing_key = parse_ed25519_signing_key(signing_secret_str)?;
let signature = signing_key.sign(msg.as_bytes());
Ok(B64.encode(signature.to_bytes()))
}
pub fn verify_b64(verify_pub_str: &str, msg: &str, sig_b64: &str) -> Result<bool, AsymmetricCryptoError> {
let verifying_key = parse_ed25519_verifying_key(verify_pub_str)?;
let sig_bytes = B64.decode(sig_b64)?;
let signature = Signature::from_slice(&sig_bytes).map_err(|_| AsymmetricCryptoError::SignatureLen)?;
Ok(verifying_key.verify(msg.as_bytes(), &signature).is_ok())
}

View File

@@ -0,0 +1,15 @@
[package]
name = "libdbstorage"
version = "0.1.0"
edition = "2021"
[dependencies]
redb = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
thiserror = { workspace = true }
# Local Crate Dependencies
libcrypto = { path = "../libcrypto" }
tokio = { version = "1", features = ["full"] }
bincode = "1.3.3"

View File

@@ -0,0 +1,89 @@
use std::num::ParseIntError;
use tokio::sync::mpsc;
use redb;
use bincode;
// todo: more error types
#[derive(Debug)]
pub struct DBError(pub String);
impl From<std::io::Error> for DBError {
fn from(item: std::io::Error) -> Self {
DBError(item.to_string().clone())
}
}
impl From<ParseIntError> for DBError {
fn from(item: ParseIntError) -> Self {
DBError(item.to_string().clone())
}
}
impl From<std::str::Utf8Error> for DBError {
fn from(item: std::str::Utf8Error) -> Self {
DBError(item.to_string().clone())
}
}
impl From<std::string::FromUtf8Error> for DBError {
fn from(item: std::string::FromUtf8Error) -> Self {
DBError(item.to_string().clone())
}
}
impl From<redb::Error> for DBError {
fn from(item: redb::Error) -> Self {
DBError(item.to_string())
}
}
impl From<redb::DatabaseError> for DBError {
fn from(item: redb::DatabaseError) -> Self {
DBError(item.to_string())
}
}
impl From<redb::TransactionError> for DBError {
fn from(item: redb::TransactionError) -> Self {
DBError(item.to_string())
}
}
impl From<redb::TableError> for DBError {
fn from(item: redb::TableError) -> Self {
DBError(item.to_string())
}
}
impl From<redb::StorageError> for DBError {
fn from(item: redb::StorageError) -> Self {
DBError(item.to_string())
}
}
impl From<redb::CommitError> for DBError {
fn from(item: redb::CommitError) -> Self {
DBError(item.to_string())
}
}
impl From<Box<bincode::ErrorKind>> for DBError {
fn from(item: Box<bincode::ErrorKind>) -> Self {
DBError(item.to_string())
}
}
impl From<tokio::sync::mpsc::error::SendError<()>> for DBError {
fn from(item: mpsc::error::SendError<()>) -> Self {
DBError(item.to_string().clone())
}
}
impl From<serde_json::Error> for DBError {
fn from(item: serde_json::Error) -> Self {
DBError(item.to_string())
}
}

View File

@@ -0,0 +1,119 @@
// In crates/libdbstorage/src/lib.rs
use std::{
path::Path,
time::{SystemTime, UNIX_EPOCH},
};
use libcrypto::CryptoFactory; // Correct import
use redb::{Database, TableDefinition};
use serde::{Deserialize, Serialize};
pub mod error; // Declare the error module
pub use error::DBError; // Re-export for users of this crate
// Declare storage module
pub mod storage;
// Table definitions for different Redis data types
const TYPES_TABLE: TableDefinition<&str, &str> = TableDefinition::new("types");
const STRINGS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("strings");
const HASHES_TABLE: TableDefinition<(&str, &str), &[u8]> = TableDefinition::new("hashes");
const LISTS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("lists");
const STREAMS_META_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("streams_meta");
const STREAMS_DATA_TABLE: TableDefinition<(&str, &str), &[u8]> = TableDefinition::new("streams_data");
const ENCRYPTED_TABLE: TableDefinition<&str, u8> = TableDefinition::new("encrypted");
const EXPIRATION_TABLE: TableDefinition<&str, u64> = TableDefinition::new("expiration");
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct StreamEntry {
pub fields: Vec<(String, String)>,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ListValue {
pub elements: Vec<String>,
}
#[inline]
pub fn now_in_millis() -> u128 {
let start = SystemTime::now();
let duration_since_epoch = start.duration_since(UNIX_EPOCH).unwrap();
duration_since_epoch.as_millis()
}
pub struct Storage {
db: Database,
crypto: Option<CryptoFactory>,
}
impl Storage {
pub fn new(path: impl AsRef<Path>, should_encrypt: bool, master_key: Option<&str>) -> Result<Self, DBError> {
let db = Database::create(path)?;
// Create tables if they don't exist
let write_txn = db.begin_write()?;
{
let _ = write_txn.open_table(TYPES_TABLE)?;
let _ = write_txn.open_table(STRINGS_TABLE)?;
let _ = write_txn.open_table(HASHES_TABLE)?;
let _ = write_txn.open_table(LISTS_TABLE)?;
let _ = write_txn.open_table(STREAMS_META_TABLE)?;
let _ = write_txn.open_table(STREAMS_DATA_TABLE)?;
let _ = write_txn.open_table(ENCRYPTED_TABLE)?;
let _ = write_txn.open_table(EXPIRATION_TABLE)?;
}
write_txn.commit()?;
// Check if database was previously encrypted
let read_txn = db.begin_read()?;
let encrypted_table = read_txn.open_table(ENCRYPTED_TABLE)?;
let was_encrypted = encrypted_table.get("encrypted")?.map(|v| v.value() == 1).unwrap_or(false);
drop(read_txn);
let crypto = if should_encrypt || was_encrypted {
if let Some(key) = master_key {
Some(CryptoFactory::new(key.as_bytes()))
} else {
return Err(DBError("Encryption requested but no master key provided".to_string()));
}
} else {
None
};
// If we're enabling encryption for the first time, mark it
if should_encrypt && !was_encrypted {
let write_txn = db.begin_write()?;
{
let mut encrypted_table = write_txn.open_table(ENCRYPTED_TABLE)?;
encrypted_table.insert("encrypted", &1u8)?;
}
write_txn.commit()?;
}
Ok(Storage {
db,
crypto,
})
}
pub fn is_encrypted(&self) -> bool {
self.crypto.is_some()
}
// Helper methods for encryption
fn encrypt_if_needed(&self, data: &[u8]) -> Result<Vec<u8>, DBError> {
if let Some(crypto) = &self.crypto {
Ok(crypto.encrypt(data))
} else {
Ok(data.to_vec())
}
}
fn decrypt_if_needed(&self, data: &[u8]) -> Result<Vec<u8>, DBError> {
if let Some(crypto) = &self.crypto {
Ok(crypto.decrypt(data).map_err(|e| DBError(e.to_string()))?)
} else {
Ok(data.to_vec())
}
}
}

View File

@@ -0,0 +1,4 @@
pub mod storage_basic;
pub mod storage_hset;
pub mod storage_lists;
pub mod storage_extra;

View File

@@ -0,0 +1,218 @@
use redb::{ReadableTable};
use crate::error::DBError;
use crate::{Storage, TYPES_TABLE, STRINGS_TABLE, HASHES_TABLE, LISTS_TABLE, STREAMS_META_TABLE, STREAMS_DATA_TABLE, EXPIRATION_TABLE, now_in_millis};
impl Storage {
pub fn flushdb(&self) -> Result<(), DBError> {
let write_txn = self.db.begin_write()?;
{
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
let mut strings_table = write_txn.open_table(STRINGS_TABLE)?;
let mut hashes_table = write_txn.open_table(HASHES_TABLE)?;
let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
let mut streams_meta_table = write_txn.open_table(STREAMS_META_TABLE)?;
let mut streams_data_table = write_txn.open_table(STREAMS_DATA_TABLE)?;
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
// inefficient, but there is no other way
let keys: Vec<String> = types_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect();
for key in keys {
types_table.remove(key.as_str())?;
}
let keys: Vec<String> = strings_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect();
for key in keys {
strings_table.remove(key.as_str())?;
}
let keys: Vec<(String, String)> = hashes_table
.iter()?
.map(|item| {
let binding = item.unwrap();
let (k, f) = binding.0.value();
(k.to_string(), f.to_string())
})
.collect();
for (key, field) in keys {
hashes_table.remove((key.as_str(), field.as_str()))?;
}
let keys: Vec<String> = lists_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect();
for key in keys {
lists_table.remove(key.as_str())?;
}
let keys: Vec<String> = streams_meta_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect();
for key in keys {
streams_meta_table.remove(key.as_str())?;
}
let keys: Vec<(String,String)> = streams_data_table.iter()?.map(|item| {
let binding = item.unwrap();
let (key, field) = binding.0.value();
(key.to_string(), field.to_string())
}).collect();
for (key, field) in keys {
streams_data_table.remove((key.as_str(), field.as_str()))?;
}
let keys: Vec<String> = expiration_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect();
for key in keys {
expiration_table.remove(key.as_str())?;
}
}
write_txn.commit()?;
Ok(())
}
pub fn get_key_type(&self, key: &str) -> Result<Option<String>, DBError> {
let read_txn = self.db.begin_read()?;
let table = read_txn.open_table(TYPES_TABLE)?;
// Before returning type, check for expiration
if let Some(type_val) = table.get(key)? {
if type_val.value() == "string" {
let expiration_table = read_txn.open_table(EXPIRATION_TABLE)?;
if let Some(expires_at) = expiration_table.get(key)? {
if now_in_millis() > expires_at.value() as u128 {
// The key is expired, so it effectively has no type
return Ok(None);
}
}
}
Ok(Some(type_val.value().to_string()))
} else {
Ok(None)
}
}
// ✅ ENCRYPTION APPLIED: Value is encrypted/decrypted
pub fn get(&self, key: &str) -> Result<Option<String>, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "string" => {
// Check expiration first (unencrypted)
let expiration_table = read_txn.open_table(EXPIRATION_TABLE)?;
if let Some(expires_at) = expiration_table.get(key)? {
if now_in_millis() > expires_at.value() as u128 {
drop(read_txn);
self.del(key.to_string())?;
return Ok(None);
}
}
// Get and decrypt value
let strings_table = read_txn.open_table(STRINGS_TABLE)?;
match strings_table.get(key)? {
Some(data) => {
let decrypted = self.decrypt_if_needed(data.value())?;
let value = String::from_utf8(decrypted)?;
Ok(Some(value))
}
None => Ok(None),
}
}
_ => Ok(None),
}
}
// ✅ ENCRYPTION APPLIED: Value is encrypted before storage
pub fn set(&self, key: String, value: String) -> Result<(), DBError> {
let write_txn = self.db.begin_write()?;
{
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
types_table.insert(key.as_str(), "string")?;
let mut strings_table = write_txn.open_table(STRINGS_TABLE)?;
// Only encrypt the value, not expiration
let encrypted = self.encrypt_if_needed(value.as_bytes())?;
strings_table.insert(key.as_str(), encrypted.as_slice())?;
// Remove any existing expiration since this is a regular SET
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
expiration_table.remove(key.as_str())?;
}
write_txn.commit()?;
Ok(())
}
// ✅ ENCRYPTION APPLIED: Value is encrypted before storage
pub fn setx(&self, key: String, value: String, expire_ms: u128) -> Result<(), DBError> {
let write_txn = self.db.begin_write()?;
{
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
types_table.insert(key.as_str(), "string")?;
let mut strings_table = write_txn.open_table(STRINGS_TABLE)?;
// Only encrypt the value
let encrypted = self.encrypt_if_needed(value.as_bytes())?;
strings_table.insert(key.as_str(), encrypted.as_slice())?;
// Store expiration separately (unencrypted)
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
let expires_at = expire_ms + now_in_millis();
expiration_table.insert(key.as_str(), &(expires_at as u64))?;
}
write_txn.commit()?;
Ok(())
}
pub fn del(&self, key: String) -> Result<(), DBError> {
let write_txn = self.db.begin_write()?;
{
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
let mut strings_table = write_txn.open_table(STRINGS_TABLE)?;
let mut hashes_table: redb::Table<(&str, &str), &[u8]> = write_txn.open_table(HASHES_TABLE)?;
let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
// Remove from type table
types_table.remove(key.as_str())?;
// Remove from strings table
strings_table.remove(key.as_str())?;
// Remove all hash fields for this key
let mut to_remove = Vec::new();
let mut iter = hashes_table.iter()?;
while let Some(entry) = iter.next() {
let entry = entry?;
let (hash_key, field) = entry.0.value();
if hash_key == key.as_str() {
to_remove.push((hash_key.to_string(), field.to_string()));
}
}
drop(iter);
for (hash_key, field) in to_remove {
hashes_table.remove((hash_key.as_str(), field.as_str()))?;
}
// Remove from lists table
lists_table.remove(key.as_str())?;
// Also remove expiration
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
expiration_table.remove(key.as_str())?;
}
write_txn.commit()?;
Ok(())
}
pub fn keys(&self, pattern: &str) -> Result<Vec<String>, DBError> {
let read_txn = self.db.begin_read()?;
let table = read_txn.open_table(TYPES_TABLE)?;
let mut keys = Vec::new();
let mut iter = table.iter()?;
while let Some(entry) = iter.next() {
let key = entry?.0.value().to_string();
if pattern == "*" || super::storage_extra::glob_match(pattern, &key) {
keys.push(key);
}
}
Ok(keys)
}
}

View File

@@ -0,0 +1,168 @@
use redb::{ReadableTable};
use crate::error::DBError;
use crate::{Storage, TYPES_TABLE, STRINGS_TABLE, EXPIRATION_TABLE, now_in_millis};
impl Storage {
// ✅ ENCRYPTION APPLIED: Values are decrypted after retrieval
pub fn scan(&self, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<(String, String)>), DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
let strings_table = read_txn.open_table(STRINGS_TABLE)?;
let mut result = Vec::new();
let mut current_cursor = 0u64;
let limit = count.unwrap_or(10) as usize;
let mut iter = types_table.iter()?;
while let Some(entry) = iter.next() {
let entry = entry?;
let key = entry.0.value().to_string();
let key_type = entry.1.value().to_string();
if current_cursor >= cursor {
// Apply pattern matching if specified
let matches = if let Some(pat) = pattern {
glob_match(pat, &key)
} else {
true
};
if matches {
// For scan, we return key-value pairs for string types
if key_type == "string" {
if let Some(data) = strings_table.get(key.as_str())? {
let decrypted = self.decrypt_if_needed(data.value())?;
let value = String::from_utf8(decrypted)?;
result.push((key, value));
} else {
result.push((key, String::new()));
}
} else {
// For non-string types, just return the key with type as value
result.push((key, key_type));
}
if result.len() >= limit {
break;
}
}
}
current_cursor += 1;
}
let next_cursor = if result.len() < limit { 0 } else { current_cursor };
Ok((next_cursor, result))
}
pub fn ttl(&self, key: &str) -> Result<i64, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "string" => {
let expiration_table = read_txn.open_table(EXPIRATION_TABLE)?;
match expiration_table.get(key)? {
Some(expires_at) => {
let now = now_in_millis();
let expires_at_ms = expires_at.value() as u128;
if now >= expires_at_ms {
Ok(-2) // Key has expired
} else {
Ok(((expires_at_ms - now) / 1000) as i64) // TTL in seconds
}
}
None => Ok(-1), // Key exists but has no expiration
}
}
Some(_) => Ok(-1), // Key exists but is not a string (no expiration support for other types)
None => Ok(-2), // Key does not exist
}
}
pub fn exists(&self, key: &str) -> Result<bool, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "string" => {
// Check if string key has expired
let expiration_table = read_txn.open_table(EXPIRATION_TABLE)?;
if let Some(expires_at) = expiration_table.get(key)? {
if now_in_millis() > expires_at.value() as u128 {
return Ok(false); // Key has expired
}
}
Ok(true)
}
Some(_) => Ok(true), // Key exists and is not a string
None => Ok(false), // Key does not exist
}
}
}
// Utility function for glob pattern matching
pub fn glob_match(pattern: &str, text: &str) -> bool {
if pattern == "*" {
return true;
}
// Simple glob matching - supports * and ? wildcards
let pattern_chars: Vec<char> = pattern.chars().collect();
let text_chars: Vec<char> = text.chars().collect();
fn match_recursive(pattern: &[char], text: &[char], pi: usize, ti: usize) -> bool {
if pi >= pattern.len() {
return ti >= text.len();
}
if ti >= text.len() {
// Check if remaining pattern is all '*'
return pattern[pi..].iter().all(|&c| c == '*');
}
match pattern[pi] {
'*' => {
// Try matching zero or more characters
for i in ti..=text.len() {
if match_recursive(pattern, text, pi + 1, i) {
return true;
}
}
false
}
'?' => {
// Match exactly one character
match_recursive(pattern, text, pi + 1, ti + 1)
}
c => {
// Match exact character
if text[ti] == c {
match_recursive(pattern, text, pi + 1, ti + 1)
} else {
false
}
}
}
}
match_recursive(&pattern_chars, &text_chars, 0, 0)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_glob_match() {
assert!(glob_match("*", "anything"));
assert!(glob_match("hello", "hello"));
assert!(!glob_match("hello", "world"));
assert!(glob_match("h*o", "hello"));
assert!(glob_match("h*o", "ho"));
assert!(!glob_match("h*o", "hi"));
assert!(glob_match("h?llo", "hello"));
assert!(!glob_match("h?llo", "hllo"));
assert!(glob_match("*test*", "this_is_a_test_string"));
assert!(!glob_match("*test*", "this_is_a_string"));
}
}

View File

@@ -0,0 +1,318 @@
use redb::{ReadableTable};
use crate::error::DBError;
use crate::{Storage, TYPES_TABLE, HASHES_TABLE};
impl Storage {
// ✅ ENCRYPTION APPLIED: Values are encrypted before storage
pub fn hset(&self, key: &str, pairs: Vec<(String, String)>) -> Result<i64, DBError> {
let write_txn = self.db.begin_write()?;
let mut new_fields = 0i64;
{
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
let mut hashes_table = write_txn.open_table(HASHES_TABLE)?;
// Set the type to hash
types_table.insert(key, "hash")?;
for (field, value) in pairs {
// Check if field already exists
let exists = hashes_table.get((key, field.as_str()))?.is_some();
// Encrypt the value before storing
let encrypted = self.encrypt_if_needed(value.as_bytes())?;
hashes_table.insert((key, field.as_str()), encrypted.as_slice())?;
if !exists {
new_fields += 1;
}
}
}
write_txn.commit()?;
Ok(new_fields)
}
// ✅ ENCRYPTION APPLIED: Value is decrypted after retrieval
pub fn hget(&self, key: &str, field: &str) -> Result<Option<String>, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "hash" => {
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
match hashes_table.get((key, field))? {
Some(data) => {
let decrypted = self.decrypt_if_needed(data.value())?;
let value = String::from_utf8(decrypted)?;
Ok(Some(value))
}
None => Ok(None),
}
}
_ => Ok(None),
}
}
// ✅ ENCRYPTION APPLIED: All values are decrypted after retrieval
pub fn hgetall(&self, key: &str) -> Result<Vec<(String, String)>, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "hash" => {
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
let mut result = Vec::new();
let mut iter = hashes_table.iter()?;
while let Some(entry) = iter.next() {
let entry = entry?;
let (hash_key, field) = entry.0.value();
if hash_key == key {
let decrypted = self.decrypt_if_needed(entry.1.value())?;
let value = String::from_utf8(decrypted)?;
result.push((field.to_string(), value));
}
}
Ok(result)
}
_ => Ok(Vec::new()),
}
}
pub fn hdel(&self, key: &str, fields: Vec<String>) -> Result<i64, DBError> {
let write_txn = self.db.begin_write()?;
let mut deleted = 0i64;
// First check if key exists and is a hash
let is_hash = {
let types_table = write_txn.open_table(TYPES_TABLE)?;
let result = match types_table.get(key)? {
Some(type_val) => type_val.value() == "hash",
None => false,
};
result
};
if is_hash {
let mut hashes_table = write_txn.open_table(HASHES_TABLE)?;
for field in fields {
if hashes_table.remove((key, field.as_str()))?.is_some() {
deleted += 1;
}
}
// Check if hash is now empty and remove type if so
let mut has_fields = false;
let mut iter = hashes_table.iter()?;
while let Some(entry) = iter.next() {
let entry = entry?;
let (hash_key, _) = entry.0.value();
if hash_key == key {
has_fields = true;
break;
}
}
drop(iter);
if !has_fields {
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
types_table.remove(key)?;
}
}
write_txn.commit()?;
Ok(deleted)
}
pub fn hexists(&self, key: &str, field: &str) -> Result<bool, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "hash" => {
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
Ok(hashes_table.get((key, field))?.is_some())
}
_ => Ok(false),
}
}
pub fn hkeys(&self, key: &str) -> Result<Vec<String>, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "hash" => {
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
let mut result = Vec::new();
let mut iter = hashes_table.iter()?;
while let Some(entry) = iter.next() {
let entry = entry?;
let (hash_key, field) = entry.0.value();
if hash_key == key {
result.push(field.to_string());
}
}
Ok(result)
}
_ => Ok(Vec::new()),
}
}
// ✅ ENCRYPTION APPLIED: All values are decrypted after retrieval
pub fn hvals(&self, key: &str) -> Result<Vec<String>, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "hash" => {
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
let mut result = Vec::new();
let mut iter = hashes_table.iter()?;
while let Some(entry) = iter.next() {
let entry = entry?;
let (hash_key, _) = entry.0.value();
if hash_key == key {
let decrypted = self.decrypt_if_needed(entry.1.value())?;
let value = String::from_utf8(decrypted)?;
result.push(value);
}
}
Ok(result)
}
_ => Ok(Vec::new()),
}
}
pub fn hlen(&self, key: &str) -> Result<i64, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "hash" => {
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
let mut count = 0i64;
let mut iter = hashes_table.iter()?;
while let Some(entry) = iter.next() {
let entry = entry?;
let (hash_key, _) = entry.0.value();
if hash_key == key {
count += 1;
}
}
Ok(count)
}
_ => Ok(0),
}
}
// ✅ ENCRYPTION APPLIED: Values are decrypted after retrieval
pub fn hmget(&self, key: &str, fields: Vec<String>) -> Result<Vec<Option<String>>, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "hash" => {
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
let mut result = Vec::new();
for field in fields {
match hashes_table.get((key, field.as_str()))? {
Some(data) => {
let decrypted = self.decrypt_if_needed(data.value())?;
let value = String::from_utf8(decrypted)?;
result.push(Some(value));
}
None => result.push(None),
}
}
Ok(result)
}
_ => Ok(fields.into_iter().map(|_| None).collect()),
}
}
// ✅ ENCRYPTION APPLIED: Value is encrypted before storage
pub fn hsetnx(&self, key: &str, field: &str, value: &str) -> Result<bool, DBError> {
let write_txn = self.db.begin_write()?;
let mut result = false;
{
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
let mut hashes_table = write_txn.open_table(HASHES_TABLE)?;
// Check if field already exists
if hashes_table.get((key, field))?.is_none() {
// Set the type to hash
types_table.insert(key, "hash")?;
// Encrypt the value before storing
let encrypted = self.encrypt_if_needed(value.as_bytes())?;
hashes_table.insert((key, field), encrypted.as_slice())?;
result = true;
}
}
write_txn.commit()?;
Ok(result)
}
// ✅ ENCRYPTION APPLIED: Values are decrypted after retrieval
pub fn hscan(&self, key: &str, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<(String, String)>), DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "hash" => {
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
let mut result = Vec::new();
let mut current_cursor = 0u64;
let limit = count.unwrap_or(10) as usize;
let mut iter = hashes_table.iter()?;
while let Some(entry) = iter.next() {
let entry = entry?;
let (hash_key, field) = entry.0.value();
if hash_key == key {
if current_cursor >= cursor {
let field_str = field.to_string();
// Apply pattern matching if specified
let matches = if let Some(pat) = pattern {
super::storage_extra::glob_match(pat, &field_str)
} else {
true
};
if matches {
let decrypted = self.decrypt_if_needed(entry.1.value())?;
let value = String::from_utf8(decrypted)?;
result.push((field_str, value));
if result.len() >= limit {
break;
}
}
}
current_cursor += 1;
}
}
let next_cursor = if result.len() < limit { 0 } else { current_cursor };
Ok((next_cursor, result))
}
_ => Ok((0, Vec::new())),
}
}
}

View File

@@ -0,0 +1,403 @@
use redb::{ReadableTable};
use crate::error::DBError;
use crate::{Storage, TYPES_TABLE, LISTS_TABLE};
impl Storage {
// ✅ ENCRYPTION APPLIED: Elements are encrypted before storage
pub fn lpush(&self, key: &str, elements: Vec<String>) -> Result<i64, DBError> {
let write_txn = self.db.begin_write()?;
let mut _length = 0i64;
{
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
// Set the type to list
types_table.insert(key, "list")?;
// Get current list or create empty one
let mut list: Vec<String> = match lists_table.get(key)? {
Some(data) => {
let decrypted = self.decrypt_if_needed(data.value())?;
serde_json::from_slice(&decrypted)?
}
None => Vec::new(),
};
// Add elements to the front (left)
for element in elements.into_iter().rev() {
list.insert(0, element);
}
_length = list.len() as i64;
// Encrypt and store the updated list
let serialized = serde_json::to_vec(&list)?;
let encrypted = self.encrypt_if_needed(&serialized)?;
lists_table.insert(key, encrypted.as_slice())?;
}
write_txn.commit()?;
Ok(_length)
}
// ✅ ENCRYPTION APPLIED: Elements are encrypted before storage
pub fn rpush(&self, key: &str, elements: Vec<String>) -> Result<i64, DBError> {
let write_txn = self.db.begin_write()?;
let mut _length = 0i64;
{
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
// Set the type to list
types_table.insert(key, "list")?;
// Get current list or create empty one
let mut list: Vec<String> = match lists_table.get(key)? {
Some(data) => {
let decrypted = self.decrypt_if_needed(data.value())?;
serde_json::from_slice(&decrypted)?
}
None => Vec::new(),
};
// Add elements to the end (right)
list.extend(elements);
_length = list.len() as i64;
// Encrypt and store the updated list
let serialized = serde_json::to_vec(&list)?;
let encrypted = self.encrypt_if_needed(&serialized)?;
lists_table.insert(key, encrypted.as_slice())?;
}
write_txn.commit()?;
Ok(_length)
}
// ✅ ENCRYPTION APPLIED: Elements are decrypted after retrieval and encrypted before storage
pub fn lpop(&self, key: &str, count: u64) -> Result<Vec<String>, DBError> {
let write_txn = self.db.begin_write()?;
let mut result = Vec::new();
// First check if key exists and is a list, and get the data
let list_data = {
let types_table = write_txn.open_table(TYPES_TABLE)?;
let lists_table = write_txn.open_table(LISTS_TABLE)?;
let result = match types_table.get(key)? {
Some(type_val) if type_val.value() == "list" => {
if let Some(data) = lists_table.get(key)? {
let decrypted = self.decrypt_if_needed(data.value())?;
let list: Vec<String> = serde_json::from_slice(&decrypted)?;
Some(list)
} else {
None
}
}
_ => None,
};
result
};
if let Some(mut list) = list_data {
let pop_count = std::cmp::min(count as usize, list.len());
for _ in 0..pop_count {
if !list.is_empty() {
result.push(list.remove(0));
}
}
let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
if list.is_empty() {
// Remove the key if list is empty
lists_table.remove(key)?;
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
types_table.remove(key)?;
} else {
// Encrypt and store the updated list
let serialized = serde_json::to_vec(&list)?;
let encrypted = self.encrypt_if_needed(&serialized)?;
lists_table.insert(key, encrypted.as_slice())?;
}
}
write_txn.commit()?;
Ok(result)
}
// ✅ ENCRYPTION APPLIED: Elements are decrypted after retrieval and encrypted before storage
pub fn rpop(&self, key: &str, count: u64) -> Result<Vec<String>, DBError> {
let write_txn = self.db.begin_write()?;
let mut result = Vec::new();
// First check if key exists and is a list, and get the data
let list_data = {
let types_table = write_txn.open_table(TYPES_TABLE)?;
let lists_table = write_txn.open_table(LISTS_TABLE)?;
let result = match types_table.get(key)? {
Some(type_val) if type_val.value() == "list" => {
if let Some(data) = lists_table.get(key)? {
let decrypted = self.decrypt_if_needed(data.value())?;
let list: Vec<String> = serde_json::from_slice(&decrypted)?;
Some(list)
} else {
None
}
}
_ => None,
};
result
};
if let Some(mut list) = list_data {
let pop_count = std::cmp::min(count as usize, list.len());
for _ in 0..pop_count {
if !list.is_empty() {
result.push(list.pop().unwrap());
}
}
let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
if list.is_empty() {
// Remove the key if list is empty
lists_table.remove(key)?;
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
types_table.remove(key)?;
} else {
// Encrypt and store the updated list
let serialized = serde_json::to_vec(&list)?;
let encrypted = self.encrypt_if_needed(&serialized)?;
lists_table.insert(key, encrypted.as_slice())?;
}
}
write_txn.commit()?;
Ok(result)
}
pub fn llen(&self, key: &str) -> Result<i64, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "list" => {
let lists_table = read_txn.open_table(LISTS_TABLE)?;
match lists_table.get(key)? {
Some(data) => {
let decrypted = self.decrypt_if_needed(data.value())?;
let list: Vec<String> = serde_json::from_slice(&decrypted)?;
Ok(list.len() as i64)
}
None => Ok(0),
}
}
_ => Ok(0),
}
}
// ✅ ENCRYPTION APPLIED: Element is decrypted after retrieval
pub fn lindex(&self, key: &str, index: i64) -> Result<Option<String>, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "list" => {
let lists_table = read_txn.open_table(LISTS_TABLE)?;
match lists_table.get(key)? {
Some(data) => {
let decrypted = self.decrypt_if_needed(data.value())?;
let list: Vec<String> = serde_json::from_slice(&decrypted)?;
let actual_index = if index < 0 {
list.len() as i64 + index
} else {
index
};
if actual_index >= 0 && (actual_index as usize) < list.len() {
Ok(Some(list[actual_index as usize].clone()))
} else {
Ok(None)
}
}
None => Ok(None),
}
}
_ => Ok(None),
}
}
// ✅ ENCRYPTION APPLIED: Elements are decrypted after retrieval
pub fn lrange(&self, key: &str, start: i64, stop: i64) -> Result<Vec<String>, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "list" => {
let lists_table = read_txn.open_table(LISTS_TABLE)?;
match lists_table.get(key)? {
Some(data) => {
let decrypted = self.decrypt_if_needed(data.value())?;
let list: Vec<String> = serde_json::from_slice(&decrypted)?;
if list.is_empty() {
return Ok(Vec::new());
}
let len = list.len() as i64;
let start_idx = if start < 0 { std::cmp::max(0, len + start) } else { std::cmp::min(start, len) };
let stop_idx = if stop < 0 { std::cmp::max(-1, len + stop) } else { std::cmp::min(stop, len - 1) };
if start_idx > stop_idx || start_idx >= len {
return Ok(Vec::new());
}
let start_usize = start_idx as usize;
let stop_usize = (stop_idx + 1) as usize;
Ok(list[start_usize..std::cmp::min(stop_usize, list.len())].to_vec())
}
None => Ok(Vec::new()),
}
}
_ => Ok(Vec::new()),
}
}
// ✅ ENCRYPTION APPLIED: Elements are decrypted after retrieval and encrypted before storage
pub fn ltrim(&self, key: &str, start: i64, stop: i64) -> Result<(), DBError> {
let write_txn = self.db.begin_write()?;
// First check if key exists and is a list, and get the data
let list_data = {
let types_table = write_txn.open_table(TYPES_TABLE)?;
let lists_table = write_txn.open_table(LISTS_TABLE)?;
let result = match types_table.get(key)? {
Some(type_val) if type_val.value() == "list" => {
if let Some(data) = lists_table.get(key)? {
let decrypted = self.decrypt_if_needed(data.value())?;
let list: Vec<String> = serde_json::from_slice(&decrypted)?;
Some(list)
} else {
None
}
}
_ => None,
};
result
};
if let Some(list) = list_data {
if list.is_empty() {
write_txn.commit()?;
return Ok(());
}
let len = list.len() as i64;
let start_idx = if start < 0 { std::cmp::max(0, len + start) } else { std::cmp::min(start, len) };
let stop_idx = if stop < 0 { std::cmp::max(-1, len + stop) } else { std::cmp::min(stop, len - 1) };
let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
if start_idx > stop_idx || start_idx >= len {
// Remove the entire list
lists_table.remove(key)?;
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
types_table.remove(key)?;
} else {
let start_usize = start_idx as usize;
let stop_usize = (stop_idx + 1) as usize;
let trimmed = list[start_usize..std::cmp::min(stop_usize, list.len())].to_vec();
if trimmed.is_empty() {
lists_table.remove(key)?;
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
types_table.remove(key)?;
} else {
// Encrypt and store the trimmed list
let serialized = serde_json::to_vec(&trimmed)?;
let encrypted = self.encrypt_if_needed(&serialized)?;
lists_table.insert(key, encrypted.as_slice())?;
}
}
}
write_txn.commit()?;
Ok(())
}
// ✅ ENCRYPTION APPLIED: Elements are decrypted after retrieval and encrypted before storage
pub fn lrem(&self, key: &str, count: i64, element: &str) -> Result<i64, DBError> {
let write_txn = self.db.begin_write()?;
let mut removed = 0i64;
// First check if key exists and is a list, and get the data
let list_data = {
let types_table = write_txn.open_table(TYPES_TABLE)?;
let lists_table = write_txn.open_table(LISTS_TABLE)?;
let result = match types_table.get(key)? {
Some(type_val) if type_val.value() == "list" => {
if let Some(data) = lists_table.get(key)? {
let decrypted = self.decrypt_if_needed(data.value())?;
let list: Vec<String> = serde_json::from_slice(&decrypted)?;
Some(list)
} else {
None
}
}
_ => None,
};
result
};
if let Some(mut list) = list_data {
if count == 0 {
// Remove all occurrences
let original_len = list.len();
list.retain(|x| x != element);
removed = (original_len - list.len()) as i64;
} else if count > 0 {
// Remove first count occurrences
let mut to_remove = count as usize;
list.retain(|x| {
if x == element && to_remove > 0 {
to_remove -= 1;
removed += 1;
false
} else {
true
}
});
} else {
// Remove last |count| occurrences
let mut to_remove = (-count) as usize;
for i in (0..list.len()).rev() {
if list[i] == element && to_remove > 0 {
list.remove(i);
to_remove -= 1;
removed += 1;
}
}
}
let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
if list.is_empty() {
lists_table.remove(key)?;
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
types_table.remove(key)?;
} else {
// Encrypt and store the updated list
let serialized = serde_json::to_vec(&list)?;
let encrypted = self.encrypt_if_needed(&serialized)?;
lists_table.insert(key, encrypted.as_slice())?;
}
}
write_txn.commit()?;
Ok(removed)
}
}

View File

@@ -0,0 +1,9 @@
[package]
name = "supervisor"
version = "0.1.0"
edition = "2021"
[dependencies]
# The supervisor will eventually depend on the herodb crate.
# We can add this dependency now.
# herodb = { path = "../herodb" }

View File

@@ -0,0 +1,4 @@
fn main() {
println!("Hello from the supervisor crate!");
// Supervisor logic will be implemented here.
}

View File

@@ -0,0 +1,18 @@
[package]
name = "supervisorrpc"
version = "0.1.0"
edition = "2021"
[[bin]]
name = "supervisorrpc"
path = "src/main.rs"
[dependencies]
# Example dependencies for an RPC server
# axum = "0.7"
# jsonrpsee = { version = "0.22", features = ["server"] }
# openrpc-types = "0.7"
tokio = { workspace = true }
redis = { version = "0.24", features = ["tokio-comp"] }
herocrypto = { path = "../herocrypto" }

View File

@@ -0,0 +1,12 @@
// To be implemented:
// 1. Define an OpenRPC schema for supervisor functions (e.g., server status, key rotation).
// 2. Implement an HTTP/TCP server (e.g., using Axum or jsonrpsee) that serves the schema
// and handles RPC calls.
// 3. Implement support for Unix domain sockets in addition to TCP.
// 4. Use the `herocrypto` or `redis-rs` crate to interact with the main `herodb` instance.
#[tokio::main]
async fn main() {
println!("Supervisor RPC server starting... (not implemented)");
// Server setup code will go here.
}

25
run_tests.sh Executable file
View File

@@ -0,0 +1,25 @@
#!/bin/bash
echo "🧪 Running HeroDB Redis Compatibility Tests"
echo "=========================================="
echo ""
echo "1⃣ Running Simple Redis Tests (4 tests)..."
echo "----------------------------------------------"
cargo test -p herodb --test simple_redis_test -- --nocapture
echo ""
echo "2⃣ Running Comprehensive Redis Integration Tests (13 tests)..."
echo "----------------------------------------------------------------"
cargo test -p herodb --test redis_integration_tests -- --nocapture
cargo test -p herodb --test redis_basic_client -- --nocapture
cargo test -p herodb --test debug_hset -- --nocapture
cargo test -p herodb --test debug_hset_simple -- --nocapture
echo ""
echo "3⃣ Running All Workspace Tests..."
echo "--------------------------------"
cargo test --workspace -- --nocapture
echo ""
echo "✅ Test execution completed!"

355
test_herodb.sh Executable file
View File

@@ -0,0 +1,355 @@
#!/bin/bash
# Test script for HeroDB - Redis-compatible database with redb backend
# This script starts the server and runs comprehensive tests
set -e
# Colors for output
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
NC='\033[0m' # No Color
# Configuration
DB_DIR="./test_db"
PORT=6381
SERVER_PID=""
# Function to print colored output
print_status() {
echo -e "${BLUE}[INFO]${NC} $1"
}
print_success() {
echo -e "${GREEN}[SUCCESS]${NC} $1"
}
print_error() {
echo -e "${RED}[ERROR]${NC} $1"
}
print_warning() {
echo -e "${YELLOW}[WARNING]${NC} $1"
}
# Function to cleanup on exit
cleanup() {
if [ ! -z "$SERVER_PID" ]; then
print_status "Stopping HeroDB server (PID: $SERVER_PID)..."
kill $SERVER_PID 2>/dev/null || true
wait $SERVER_PID 2>/dev/null || true
fi
# Clean up test database
if [ -d "$DB_DIR" ]; then
print_status "Cleaning up test database directory..."
rm -rf "$DB_DIR"
fi
}
# Set trap to cleanup on script exit
trap cleanup EXIT
# Function to wait for server to start
wait_for_server() {
local max_attempts=30
local attempt=1
print_status "Waiting for server to start on port $PORT..."
while [ $attempt -le $max_attempts ]; do
if nc -z localhost $PORT 2>/dev/null; then
print_success "Server is ready!"
return 0
fi
echo -n "."
sleep 1
attempt=$((attempt + 1))
done
print_error "Server failed to start within $max_attempts seconds"
return 1
}
# Function to send Redis command and get response
redis_cmd() {
local cmd="$1"
local expected="$2"
print_status "Testing: $cmd"
local result=$(echo "$cmd" | redis-cli -p $PORT --raw 2>/dev/null || echo "ERROR")
if [ "$expected" != "" ] && [ "$result" != "$expected" ]; then
print_error "Expected: '$expected', Got: '$result'"
return 1
else
print_success "$cmd -> $result"
return 0
fi
}
# Function to test basic string operations
test_string_operations() {
print_status "=== Testing String Operations ==="
redis_cmd "PING" "PONG"
redis_cmd "SET mykey hello" "OK"
redis_cmd "GET mykey" "hello"
redis_cmd "SET counter 1" "OK"
redis_cmd "INCR counter" "2"
redis_cmd "INCR counter" "3"
redis_cmd "GET counter" "3"
redis_cmd "DEL mykey" "1"
redis_cmd "GET mykey" ""
redis_cmd "TYPE counter" "string"
redis_cmd "TYPE nonexistent" "none"
}
# Function to test hash operations
test_hash_operations() {
print_status "=== Testing Hash Operations ==="
# HSET and HGET
redis_cmd "HSET user:1 name John" "1"
redis_cmd "HSET user:1 age 30 city NYC" "2"
redis_cmd "HGET user:1 name" "John"
redis_cmd "HGET user:1 age" "30"
redis_cmd "HGET user:1 nonexistent" ""
# HGETALL
print_status "Testing HGETALL user:1"
redis_cmd "HGETALL user:1" ""
# HEXISTS
redis_cmd "HEXISTS user:1 name" "1"
redis_cmd "HEXISTS user:1 nonexistent" "0"
# HKEYS
print_status "Testing HKEYS user:1"
redis_cmd "HKEYS user:1" ""
# HVALS
print_status "Testing HVALS user:1"
redis_cmd "HVALS user:1" ""
# HLEN
redis_cmd "HLEN user:1" "3"
# HMGET
print_status "Testing HMGET user:1 name age"
redis_cmd "HMGET user:1 name age" ""
# HSETNX
redis_cmd "HSETNX user:1 name Jane" "0" # Should not set, field exists
redis_cmd "HSETNX user:1 email john@example.com" "1" # Should set, new field
redis_cmd "HGET user:1 email" "john@example.com"
# HDEL
redis_cmd "HDEL user:1 age city" "2"
redis_cmd "HLEN user:1" "2"
redis_cmd "HEXISTS user:1 age" "0"
# Test type checking
redis_cmd "SET stringkey value" "OK"
print_status "Testing WRONGTYPE error on string key"
redis_cmd "HGET stringkey field" "" # Should return WRONGTYPE error
}
# Function to test configuration commands
test_config_operations() {
print_status "=== Testing Configuration Operations ==="
print_status "Testing CONFIG GET dir"
redis_cmd "CONFIG GET dir" ""
print_status "Testing CONFIG GET dbfilename"
redis_cmd "CONFIG GET dbfilename" ""
}
# Function to test transaction operations
test_transaction_operations() {
print_status "=== Testing Transaction Operations ==="
redis_cmd "MULTI" "OK"
redis_cmd "SET tx_key1 value1" "QUEUED"
redis_cmd "SET tx_key2 value2" "QUEUED"
redis_cmd "INCR counter" "QUEUED"
print_status "Testing EXEC"
redis_cmd "EXEC" ""
redis_cmd "GET tx_key1" "value1"
redis_cmd "GET tx_key2" "value2"
# Test DISCARD
redis_cmd "MULTI" "OK"
redis_cmd "SET discard_key value" "QUEUED"
redis_cmd "DISCARD" "OK"
redis_cmd "GET discard_key" ""
}
# Function to test keys operations
test_keys_operations() {
print_status "=== Testing Keys Operations ==="
print_status "Testing KEYS *"
redis_cmd "KEYS *" ""
}
# Function to test info operations
test_info_operations() {
print_status "=== Testing Info Operations ==="
print_status "Testing INFO"
redis_cmd "INFO" ""
print_status "Testing INFO replication"
redis_cmd "INFO replication" ""
}
# Function to test expiration
test_expiration() {
print_status "=== Testing Expiration ==="
redis_cmd "SET expire_key value" "OK"
redis_cmd "SET expire_px_key value PX 1000" "OK" # 1 second
redis_cmd "SET expire_ex_key value EX 1" "OK" # 1 second
redis_cmd "GET expire_key" "value"
redis_cmd "GET expire_px_key" "value"
redis_cmd "GET expire_ex_key" "value"
print_status "Waiting 2 seconds for expiration..."
sleep 2
redis_cmd "GET expire_key" "value" # Should still exist
redis_cmd "GET expire_px_key" "" # Should be expired
redis_cmd "GET expire_ex_key" "" # Should be expired
}
# Function to test SCAN operations
test_scan_operations() {
print_status "=== Testing SCAN Operations ==="
# Set up test data for scanning
redis_cmd "SET scan_test1 value1" "OK"
redis_cmd "SET scan_test2 value2" "OK"
redis_cmd "SET scan_test3 value3" "OK"
redis_cmd "SET other_key other_value" "OK"
redis_cmd "HSET scan_hash field1 value1" "1"
# Test basic SCAN
print_status "Testing basic SCAN with cursor 0"
redis_cmd "SCAN 0" ""
# Test SCAN with MATCH pattern
print_status "Testing SCAN with MATCH pattern"
redis_cmd "SCAN 0 MATCH scan_test*" ""
# Test SCAN with COUNT
print_status "Testing SCAN with COUNT 2"
redis_cmd "SCAN 0 COUNT 2" ""
# Test SCAN with both MATCH and COUNT
print_status "Testing SCAN with MATCH and COUNT"
redis_cmd "SCAN 0 MATCH scan_* COUNT 1" ""
# Test SCAN continuation with more keys
print_status "Setting up more keys for continuation test"
redis_cmd "SET scan_key1 val1" "OK"
redis_cmd "SET scan_key2 val2" "OK"
redis_cmd "SET scan_key3 val3" "OK"
redis_cmd "SET scan_key4 val4" "OK"
redis_cmd "SET scan_key5 val5" "OK"
print_status "Testing SCAN with small COUNT for pagination"
redis_cmd "SCAN 0 COUNT 3" ""
# Clean up SCAN test data
print_status "Cleaning up SCAN test data"
redis_cmd "DEL scan_test1" "1"
redis_cmd "DEL scan_test2" "1"
redis_cmd "DEL scan_test3" "1"
redis_cmd "DEL other_key" "1"
redis_cmd "DEL scan_hash" "1"
redis_cmd "DEL scan_key1" "1"
redis_cmd "DEL scan_key2" "1"
redis_cmd "DEL scan_key3" "1"
redis_cmd "DEL scan_key4" "1"
redis_cmd "DEL scan_key5" "1"
}
# Main execution
main() {
print_status "Starting HeroDB comprehensive test suite..."
# Build the project
print_status "Building HeroDB..."
if ! cargo build -p herodb --release; then
print_error "Failed to build HeroDB"
exit 1
fi
# Create test database directory
mkdir -p "$DB_DIR"
# Start the server
print_status "Starting HeroDB server..."
./target/release/herodb --dir "$DB_DIR" --port $PORT &
SERVER_PID=$!
# Wait for server to start
if ! wait_for_server; then
print_error "Failed to start server"
exit 1
fi
# Run tests
local failed_tests=0
test_string_operations || failed_tests=$((failed_tests + 1))
test_hash_operations || failed_tests=$((failed_tests + 1))
test_config_operations || failed_tests=$((failed_tests + 1))
test_transaction_operations || failed_tests=$((failed_tests + 1))
test_keys_operations || failed_tests=$((failed_tests + 1))
test_info_operations || failed_tests=$((failed_tests + 1))
test_expiration || failed_tests=$((failed_tests + 1))
test_scan_operations || failed_tests=$((failed_tests + 1))
# Summary
echo
print_status "=== Test Summary ==="
if [ $failed_tests -eq 0 ]; then
print_success "All tests completed! Some may have warnings due to protocol differences."
print_success "HeroDB is working with persistent redb storage!"
else
print_warning "$failed_tests test categories had issues"
print_warning "Check the output above for details"
fi
print_status "Database file created at: $DB_DIR/herodb.redb"
print_status "Server logs and any errors are shown above"
}
# Check dependencies
check_dependencies() {
if ! command -v cargo &> /dev/null; then
print_error "cargo is required but not installed"
exit 1
fi
if ! command -v nc &> /dev/null; then
print_warning "netcat (nc) not found - some tests may not work properly"
fi
if ! command -v redis-cli &> /dev/null; then
print_warning "redis-cli not found - using netcat fallback"
fi
}
# Run dependency check and main function
check_dependencies
main "$@"