...
This commit is contained in:
54
src/cmd.rs
54
src/cmd.rs
@@ -1,4 +1,5 @@
|
||||
use crate::{error::DBError, protocol::Protocol, server::Server};
|
||||
use serde::Serialize;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Cmd {
|
||||
@@ -444,7 +445,7 @@ impl Cmd {
|
||||
Cmd::Del(k) => del_cmd(server, k).await,
|
||||
Cmd::ConfigGet(name) => config_get_cmd(name, server),
|
||||
Cmd::Keys => keys_cmd(server).await,
|
||||
Cmd::Info(section) => info_cmd(section),
|
||||
Cmd::Info(section) => info_cmd(server, section).await,
|
||||
Cmd::Type(k) => type_cmd(server, k).await,
|
||||
Cmd::Incr(key) => incr_cmd(server, key).await,
|
||||
Cmd::Multi => {
|
||||
@@ -640,19 +641,23 @@ async fn incr_cmd(server: &Server, key: &String) -> Result<Protocol, DBError> {
|
||||
}
|
||||
|
||||
fn config_get_cmd(name: &String, server: &Server) -> Result<Protocol, DBError> {
|
||||
let mut result = Vec::new();
|
||||
result.push(Protocol::BulkString(name.clone()));
|
||||
|
||||
match name.as_str() {
|
||||
"dir" => Ok(Protocol::Array(vec![
|
||||
Protocol::BulkString(name.clone()),
|
||||
Protocol::BulkString(server.option.dir.clone()),
|
||||
])),
|
||||
"dbfilename" => Ok(Protocol::Array(vec![
|
||||
Protocol::BulkString(name.clone()),
|
||||
Protocol::BulkString(format!("{}.db", server.selected_db)),
|
||||
])),
|
||||
"databases" => Ok(Protocol::Array(vec![
|
||||
Protocol::BulkString(name.clone()),
|
||||
Protocol::BulkString(server.option.max_databases.unwrap_or(0).to_string()),
|
||||
])),
|
||||
"dir" => {
|
||||
result.push(Protocol::BulkString(server.option.dir.clone()));
|
||||
Ok(Protocol::Array(result))
|
||||
}
|
||||
"dbfilename" => {
|
||||
result.push(Protocol::BulkString(format!("{}.db", server.selected_db)));
|
||||
Ok(Protocol::Array(result))
|
||||
},
|
||||
"databases" => {
|
||||
// This is hardcoded, as the feature was removed
|
||||
result.push(Protocol::BulkString("16".to_string()));
|
||||
Ok(Protocol::Array(result))
|
||||
},
|
||||
_ => Ok(Protocol::Array(vec![])),
|
||||
}
|
||||
}
|
||||
@@ -664,7 +669,24 @@ async fn keys_cmd(server: &Server) -> Result<Protocol, DBError> {
|
||||
))
|
||||
}
|
||||
|
||||
fn info_cmd(section: &Option<String>) -> Result<Protocol, DBError> {
|
||||
#[derive(Serialize)]
|
||||
struct ServerInfo {
|
||||
redis_version: String,
|
||||
encrypted: bool,
|
||||
}
|
||||
|
||||
async fn info_cmd(server: &Server, section: &Option<String>) -> Result<Protocol, DBError> {
|
||||
let info = ServerInfo {
|
||||
redis_version: "7.0.0".to_string(),
|
||||
encrypted: server.current_storage()?.is_encrypted(),
|
||||
};
|
||||
|
||||
let mut info_string = String::new();
|
||||
info_string.push_str(&format!("# Server\n"));
|
||||
info_string.push_str(&format!("redis_version:{}\n", info.redis_version));
|
||||
info_string.push_str(&format!("encrypted:{}\n", if info.encrypted { 1 } else { 0 }));
|
||||
|
||||
|
||||
match section {
|
||||
Some(s) => match s.as_str() {
|
||||
"replication" => Ok(Protocol::BulkString(
|
||||
@@ -672,7 +694,9 @@ fn info_cmd(section: &Option<String>) -> Result<Protocol, DBError> {
|
||||
)),
|
||||
_ => Err(DBError(format!("unsupported section {:?}", s))),
|
||||
},
|
||||
None => Ok(Protocol::BulkString("# Server\nredis_version:7.0.0\n".to_string())),
|
||||
None => {
|
||||
Ok(Protocol::BulkString(info_string))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -22,13 +22,14 @@ struct Args {
|
||||
#[arg(long)]
|
||||
debug: bool,
|
||||
|
||||
/// Maximum number of logical databases (None = unlimited)
|
||||
#[arg(long)]
|
||||
max_databases: Option<u64>,
|
||||
|
||||
/// Master encryption key for encrypted databases
|
||||
#[arg(long)]
|
||||
encryption_key: Option<String>,
|
||||
|
||||
/// Encrypt the database
|
||||
#[arg(long)]
|
||||
encrypt: bool,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
@@ -48,8 +49,8 @@ async fn main() {
|
||||
dir: args.dir,
|
||||
port,
|
||||
debug: args.debug,
|
||||
max_databases: args.max_databases,
|
||||
encryption_key: args.encryption_key,
|
||||
encrypt: args.encrypt,
|
||||
};
|
||||
|
||||
// new server
|
||||
|
@@ -3,6 +3,6 @@ pub struct DBOption {
|
||||
pub dir: String,
|
||||
pub port: u16,
|
||||
pub debug: bool,
|
||||
pub max_databases: Option<u64>, // None = unlimited, Some(n) = limit to n
|
||||
pub encrypt: bool,
|
||||
pub encryption_key: Option<String>, // Master encryption key
|
||||
}
|
||||
|
@@ -35,12 +35,6 @@ impl Server {
|
||||
return Ok(storage.clone());
|
||||
}
|
||||
|
||||
// Check database limit if set
|
||||
if let Some(max_db) = self.option.max_databases {
|
||||
if self.selected_db >= max_db {
|
||||
return Err(DBError(format!("DB index {} is out of range (max: {})", self.selected_db, max_db - 1)));
|
||||
}
|
||||
}
|
||||
|
||||
// Create new database file
|
||||
let db_file_path = std::path::PathBuf::from(self.option.dir.clone())
|
||||
@@ -58,10 +52,8 @@ impl Server {
|
||||
Ok(storage)
|
||||
}
|
||||
|
||||
fn should_encrypt_db(&self, db_index: u64) -> bool {
|
||||
// You can implement logic here to determine which databases should be encrypted
|
||||
// For now, let's say databases with even numbers are encrypted if key is provided
|
||||
self.option.encryption_key.is_some() && db_index % 2 == 0
|
||||
fn should_encrypt_db(&self, _db_index: u64) -> bool {
|
||||
self.option.encrypt
|
||||
}
|
||||
|
||||
pub async fn handle(
|
||||
|
149
src/storage.rs
149
src/storage.rs
@@ -104,17 +104,12 @@ fn glob_match(pattern: &str, text: &str) -> bool {
|
||||
// Table definitions for different Redis data types
|
||||
const TYPES_TABLE: TableDefinition<&str, &str> = TableDefinition::new("types");
|
||||
const STRINGS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("strings");
|
||||
const HASHES_TABLE: TableDefinition<(&str, &str), &str> = TableDefinition::new("hashes");
|
||||
const HASHES_TABLE: TableDefinition<(&str, &str), &[u8]> = TableDefinition::new("hashes");
|
||||
const LISTS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("lists");
|
||||
const STREAMS_META_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("streams_meta");
|
||||
const STREAMS_DATA_TABLE: TableDefinition<(&str, &str), &[u8]> = TableDefinition::new("streams_data");
|
||||
const ENCRYPTED_TABLE: TableDefinition<&str, u8> = TableDefinition::new("encrypted");
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct StringValue {
|
||||
pub value: String,
|
||||
pub expires_at_ms: Option<u128>,
|
||||
}
|
||||
const EXPIRATION_TABLE: TableDefinition<&str, u64> = TableDefinition::new("expiration");
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct StreamEntry {
|
||||
@@ -153,6 +148,7 @@ impl Storage {
|
||||
let _ = write_txn.open_table(STREAMS_META_TABLE)?;
|
||||
let _ = write_txn.open_table(STREAMS_DATA_TABLE)?;
|
||||
let _ = write_txn.open_table(ENCRYPTED_TABLE)?;
|
||||
let _ = write_txn.open_table(EXPIRATION_TABLE)?;
|
||||
}
|
||||
write_txn.commit()?;
|
||||
|
||||
@@ -218,6 +214,7 @@ impl Storage {
|
||||
let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
|
||||
let mut streams_meta_table = write_txn.open_table(STREAMS_META_TABLE)?;
|
||||
let mut streams_data_table = write_txn.open_table(STREAMS_DATA_TABLE)?;
|
||||
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
|
||||
|
||||
// inefficient, but there is no other way
|
||||
let keys: Vec<String> = types_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect();
|
||||
@@ -252,6 +249,10 @@ impl Storage {
|
||||
for (key, field) in keys {
|
||||
streams_data_table.remove((key.as_str(), field.as_str()))?;
|
||||
}
|
||||
let keys: Vec<String> = expiration_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect();
|
||||
for key in keys {
|
||||
expiration_table.remove(key.as_str())?;
|
||||
}
|
||||
}
|
||||
write_txn.commit()?;
|
||||
Ok(())
|
||||
@@ -274,22 +275,23 @@ impl Storage {
|
||||
let types_table = read_txn.open_table(TYPES_TABLE)?;
|
||||
match types_table.get(key)? {
|
||||
Some(type_val) if type_val.value() == "string" => {
|
||||
// Check expiration first (unencrypted)
|
||||
let expiration_table = read_txn.open_table(EXPIRATION_TABLE)?;
|
||||
if let Some(expires_at) = expiration_table.get(key)? {
|
||||
if now_in_millis() > expires_at.value() as u128 {
|
||||
drop(read_txn);
|
||||
self.del(key.to_string())?;
|
||||
return Ok(None);
|
||||
}
|
||||
}
|
||||
|
||||
// Get and decrypt value
|
||||
let strings_table = read_txn.open_table(STRINGS_TABLE)?;
|
||||
match strings_table.get(key)? {
|
||||
Some(data) => {
|
||||
let decrypted = self.decrypt_if_needed(data.value())?;
|
||||
let string_value: StringValue = bincode::deserialize(&decrypted)?;
|
||||
|
||||
// Check if expired
|
||||
if let Some(expires_at) = string_value.expires_at_ms {
|
||||
if now_in_millis() > expires_at {
|
||||
drop(read_txn);
|
||||
self.del(key.to_string())?;
|
||||
return Ok(None);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Some(string_value.value))
|
||||
let value = String::from_utf8(decrypted)?;
|
||||
Ok(Some(value))
|
||||
}
|
||||
None => Ok(None),
|
||||
}
|
||||
@@ -310,13 +312,13 @@ impl Storage {
|
||||
types_table.insert(key.as_str(), "string")?;
|
||||
|
||||
let mut strings_table = write_txn.open_table(STRINGS_TABLE)?;
|
||||
let string_value = StringValue {
|
||||
value,
|
||||
expires_at_ms: None,
|
||||
};
|
||||
let serialized = bincode::serialize(&string_value)?;
|
||||
let encrypted = self.encrypt_if_needed(&serialized)?;
|
||||
// Only encrypt the value, not expiration
|
||||
let encrypted = self.encrypt_if_needed(value.as_bytes())?;
|
||||
strings_table.insert(key.as_str(), encrypted.as_slice())?;
|
||||
|
||||
// Remove any existing expiration since this is a regular SET
|
||||
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
|
||||
expiration_table.remove(key.as_str())?;
|
||||
}
|
||||
|
||||
write_txn.commit()?;
|
||||
@@ -331,13 +333,14 @@ impl Storage {
|
||||
types_table.insert(key.as_str(), "string")?;
|
||||
|
||||
let mut strings_table = write_txn.open_table(STRINGS_TABLE)?;
|
||||
let string_value = StringValue {
|
||||
value,
|
||||
expires_at_ms: Some(expire_ms + now_in_millis()),
|
||||
};
|
||||
let serialized = bincode::serialize(&string_value)?;
|
||||
let encrypted = self.encrypt_if_needed(&serialized)?;
|
||||
// Only encrypt the value
|
||||
let encrypted = self.encrypt_if_needed(value.as_bytes())?;
|
||||
strings_table.insert(key.as_str(), encrypted.as_slice())?;
|
||||
|
||||
// Store expiration separately (unencrypted)
|
||||
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
|
||||
let expires_at = expire_ms + now_in_millis();
|
||||
expiration_table.insert(key.as_str(), &(expires_at as u64))?;
|
||||
}
|
||||
|
||||
write_txn.commit()?;
|
||||
@@ -377,6 +380,10 @@ impl Storage {
|
||||
|
||||
// Remove from lists table
|
||||
lists_table.remove(key.as_str())?;
|
||||
|
||||
// Also remove expiration
|
||||
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
|
||||
expiration_table.remove(key.as_str())?;
|
||||
}
|
||||
|
||||
write_txn.commit()?;
|
||||
@@ -427,7 +434,11 @@ impl Storage {
|
||||
|
||||
for (field, value) in pairs {
|
||||
let existed = hashes_table.get((key, field.as_str()))?.is_some();
|
||||
hashes_table.insert((key, field.as_str()), value.as_str())?;
|
||||
|
||||
// Encrypt the value before storing
|
||||
let encrypted = self.encrypt_if_needed(value.as_bytes())?;
|
||||
hashes_table.insert((key, field.as_str()), encrypted.as_slice())?;
|
||||
|
||||
if !existed {
|
||||
new_fields += 1;
|
||||
}
|
||||
@@ -447,7 +458,11 @@ impl Storage {
|
||||
Some(type_val) if type_val.value() == "hash" => {
|
||||
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
|
||||
match hashes_table.get((key, field))? {
|
||||
Some(value) => Ok(Some(value.value().to_string())),
|
||||
Some(data) => {
|
||||
let decrypted = self.decrypt_if_needed(data.value())?;
|
||||
let value = String::from_utf8(decrypted)?;
|
||||
Ok(Some(value))
|
||||
}
|
||||
None => Ok(None),
|
||||
}
|
||||
}
|
||||
@@ -472,7 +487,9 @@ impl Storage {
|
||||
let (hash_key, field) = entry.0.value();
|
||||
let value = entry.1.value();
|
||||
if hash_key == key {
|
||||
result.push((field.to_string(), value.to_string()));
|
||||
let decrypted = self.decrypt_if_needed(value)?;
|
||||
let value_str = String::from_utf8(decrypted)?;
|
||||
result.push((field.to_string(), value_str));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -563,7 +580,9 @@ impl Storage {
|
||||
let (hash_key, _) = entry.0.value();
|
||||
let value = entry.1.value();
|
||||
if hash_key == key {
|
||||
result.push(value.to_string());
|
||||
let decrypted = self.decrypt_if_needed(value)?;
|
||||
let value_str = String::from_utf8(decrypted)?;
|
||||
result.push(value_str);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -610,7 +629,11 @@ impl Storage {
|
||||
|
||||
for field in fields {
|
||||
match hashes_table.get((key, field.as_str()))? {
|
||||
Some(value) => result.push(Some(value.value().to_string())),
|
||||
Some(data) => {
|
||||
let decrypted = self.decrypt_if_needed(data.value())?;
|
||||
let value = String::from_utf8(decrypted)?;
|
||||
result.push(Some(value));
|
||||
}
|
||||
None => result.push(None),
|
||||
}
|
||||
}
|
||||
@@ -649,7 +672,8 @@ impl Storage {
|
||||
|
||||
// Check if field already exists
|
||||
if hashes_table.get((key, field))?.is_none() {
|
||||
hashes_table.insert((key, field), value)?;
|
||||
let encrypted_value = self.encrypt_if_needed(value.as_bytes())?;
|
||||
hashes_table.insert((key, field), encrypted_value.as_slice())?;
|
||||
result = true;
|
||||
}
|
||||
}
|
||||
@@ -769,8 +793,10 @@ impl Storage {
|
||||
};
|
||||
|
||||
if matches {
|
||||
let decrypted = self.decrypt_if_needed(value)?;
|
||||
let value_str = String::from_utf8(decrypted)?;
|
||||
fields.push(field.to_string());
|
||||
fields.push(value.to_string());
|
||||
fields.push(value_str);
|
||||
returned_fields += 1;
|
||||
|
||||
if returned_fields >= count {
|
||||
@@ -792,30 +818,24 @@ impl Storage {
|
||||
pub fn ttl(&self, key: &str) -> Result<i64, DBError> {
|
||||
let read_txn = self.db.begin_read()?;
|
||||
|
||||
// Check if key exists
|
||||
let types_table = read_txn.open_table(TYPES_TABLE)?;
|
||||
match types_table.get(key)? {
|
||||
Some(type_val) if type_val.value() == "string" => {
|
||||
let strings_table = read_txn.open_table(STRINGS_TABLE)?;
|
||||
match strings_table.get(key)? {
|
||||
Some(data) => {
|
||||
let string_value: StringValue = bincode::deserialize(data.value())?;
|
||||
match string_value.expires_at_ms {
|
||||
Some(expires_at) => {
|
||||
let now = now_in_millis();
|
||||
if now > expires_at {
|
||||
Ok(-2) // Key expired
|
||||
} else {
|
||||
Ok(((expires_at - now) / 1000) as i64) // TTL in seconds
|
||||
}
|
||||
}
|
||||
None => Ok(-1), // No expiration
|
||||
let expiration_table = read_txn.open_table(EXPIRATION_TABLE)?;
|
||||
match expiration_table.get(key)? {
|
||||
Some(expires_at) => {
|
||||
let now = now_in_millis();
|
||||
let expires_at = expires_at.value() as u128;
|
||||
if now > expires_at {
|
||||
Ok(-2) // Key expired
|
||||
} else {
|
||||
Ok(((expires_at - now) / 1000) as i64) // TTL in seconds
|
||||
}
|
||||
}
|
||||
None => Ok(-2), // Key doesn't exist
|
||||
None => Ok(-1), // No expiration
|
||||
}
|
||||
}
|
||||
Some(_) => Ok(-1), // Other types don't have TTL implemented yet
|
||||
Some(_) => Ok(-1), // Other types don't have TTL
|
||||
None => Ok(-2), // Key doesn't exist
|
||||
}
|
||||
}
|
||||
@@ -825,18 +845,13 @@ impl Storage {
|
||||
let types_table = read_txn.open_table(TYPES_TABLE)?;
|
||||
|
||||
match types_table.get(key)? {
|
||||
Some(_) => {
|
||||
Some(type_val) => {
|
||||
// For string types, check if not expired
|
||||
if let Some(type_val) = types_table.get(key)? {
|
||||
if type_val.value() == "string" {
|
||||
let strings_table = read_txn.open_table(STRINGS_TABLE)?;
|
||||
if let Some(data) = strings_table.get(key)? {
|
||||
let string_value: StringValue = bincode::deserialize(data.value())?;
|
||||
if let Some(expires_at) = string_value.expires_at_ms {
|
||||
if now_in_millis() > expires_at {
|
||||
return Ok(false); // Expired
|
||||
}
|
||||
}
|
||||
if type_val.value() == "string" {
|
||||
let expiration_table = read_txn.open_table(EXPIRATION_TABLE)?;
|
||||
if let Some(expires_at) = expiration_table.get(key)? {
|
||||
if now_in_millis() > expires_at.value() as u128 {
|
||||
return Ok(false); // Expired
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -878,7 +893,7 @@ impl Storage {
|
||||
None => ListValue { elements: Vec::new() },
|
||||
};
|
||||
|
||||
for element in elements.into_iter().rev() {
|
||||
for element in elements.into_iter() {
|
||||
list_value.elements.insert(0, element);
|
||||
}
|
||||
new_len = list_value.elements.len() as u64;
|
||||
|
Reference in New Issue
Block a user