...
This commit is contained in:
54
src/cmd.rs
54
src/cmd.rs
@@ -1,4 +1,5 @@
|
|||||||
use crate::{error::DBError, protocol::Protocol, server::Server};
|
use crate::{error::DBError, protocol::Protocol, server::Server};
|
||||||
|
use serde::Serialize;
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub enum Cmd {
|
pub enum Cmd {
|
||||||
@@ -444,7 +445,7 @@ impl Cmd {
|
|||||||
Cmd::Del(k) => del_cmd(server, k).await,
|
Cmd::Del(k) => del_cmd(server, k).await,
|
||||||
Cmd::ConfigGet(name) => config_get_cmd(name, server),
|
Cmd::ConfigGet(name) => config_get_cmd(name, server),
|
||||||
Cmd::Keys => keys_cmd(server).await,
|
Cmd::Keys => keys_cmd(server).await,
|
||||||
Cmd::Info(section) => info_cmd(section),
|
Cmd::Info(section) => info_cmd(server, section).await,
|
||||||
Cmd::Type(k) => type_cmd(server, k).await,
|
Cmd::Type(k) => type_cmd(server, k).await,
|
||||||
Cmd::Incr(key) => incr_cmd(server, key).await,
|
Cmd::Incr(key) => incr_cmd(server, key).await,
|
||||||
Cmd::Multi => {
|
Cmd::Multi => {
|
||||||
@@ -640,19 +641,23 @@ async fn incr_cmd(server: &Server, key: &String) -> Result<Protocol, DBError> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn config_get_cmd(name: &String, server: &Server) -> Result<Protocol, DBError> {
|
fn config_get_cmd(name: &String, server: &Server) -> Result<Protocol, DBError> {
|
||||||
|
let mut result = Vec::new();
|
||||||
|
result.push(Protocol::BulkString(name.clone()));
|
||||||
|
|
||||||
match name.as_str() {
|
match name.as_str() {
|
||||||
"dir" => Ok(Protocol::Array(vec![
|
"dir" => {
|
||||||
Protocol::BulkString(name.clone()),
|
result.push(Protocol::BulkString(server.option.dir.clone()));
|
||||||
Protocol::BulkString(server.option.dir.clone()),
|
Ok(Protocol::Array(result))
|
||||||
])),
|
}
|
||||||
"dbfilename" => Ok(Protocol::Array(vec![
|
"dbfilename" => {
|
||||||
Protocol::BulkString(name.clone()),
|
result.push(Protocol::BulkString(format!("{}.db", server.selected_db)));
|
||||||
Protocol::BulkString(format!("{}.db", server.selected_db)),
|
Ok(Protocol::Array(result))
|
||||||
])),
|
},
|
||||||
"databases" => Ok(Protocol::Array(vec![
|
"databases" => {
|
||||||
Protocol::BulkString(name.clone()),
|
// This is hardcoded, as the feature was removed
|
||||||
Protocol::BulkString(server.option.max_databases.unwrap_or(0).to_string()),
|
result.push(Protocol::BulkString("16".to_string()));
|
||||||
])),
|
Ok(Protocol::Array(result))
|
||||||
|
},
|
||||||
_ => Ok(Protocol::Array(vec![])),
|
_ => Ok(Protocol::Array(vec![])),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -664,7 +669,24 @@ async fn keys_cmd(server: &Server) -> Result<Protocol, DBError> {
|
|||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn info_cmd(section: &Option<String>) -> Result<Protocol, DBError> {
|
#[derive(Serialize)]
|
||||||
|
struct ServerInfo {
|
||||||
|
redis_version: String,
|
||||||
|
encrypted: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn info_cmd(server: &Server, section: &Option<String>) -> Result<Protocol, DBError> {
|
||||||
|
let info = ServerInfo {
|
||||||
|
redis_version: "7.0.0".to_string(),
|
||||||
|
encrypted: server.current_storage()?.is_encrypted(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut info_string = String::new();
|
||||||
|
info_string.push_str(&format!("# Server\n"));
|
||||||
|
info_string.push_str(&format!("redis_version:{}\n", info.redis_version));
|
||||||
|
info_string.push_str(&format!("encrypted:{}\n", if info.encrypted { 1 } else { 0 }));
|
||||||
|
|
||||||
|
|
||||||
match section {
|
match section {
|
||||||
Some(s) => match s.as_str() {
|
Some(s) => match s.as_str() {
|
||||||
"replication" => Ok(Protocol::BulkString(
|
"replication" => Ok(Protocol::BulkString(
|
||||||
@@ -672,7 +694,9 @@ fn info_cmd(section: &Option<String>) -> Result<Protocol, DBError> {
|
|||||||
)),
|
)),
|
||||||
_ => Err(DBError(format!("unsupported section {:?}", s))),
|
_ => Err(DBError(format!("unsupported section {:?}", s))),
|
||||||
},
|
},
|
||||||
None => Ok(Protocol::BulkString("# Server\nredis_version:7.0.0\n".to_string())),
|
None => {
|
||||||
|
Ok(Protocol::BulkString(info_string))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -22,13 +22,14 @@ struct Args {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
debug: bool,
|
debug: bool,
|
||||||
|
|
||||||
/// Maximum number of logical databases (None = unlimited)
|
|
||||||
#[arg(long)]
|
|
||||||
max_databases: Option<u64>,
|
|
||||||
|
|
||||||
/// Master encryption key for encrypted databases
|
/// Master encryption key for encrypted databases
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
encryption_key: Option<String>,
|
encryption_key: Option<String>,
|
||||||
|
|
||||||
|
/// Encrypt the database
|
||||||
|
#[arg(long)]
|
||||||
|
encrypt: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
@@ -48,8 +49,8 @@ async fn main() {
|
|||||||
dir: args.dir,
|
dir: args.dir,
|
||||||
port,
|
port,
|
||||||
debug: args.debug,
|
debug: args.debug,
|
||||||
max_databases: args.max_databases,
|
|
||||||
encryption_key: args.encryption_key,
|
encryption_key: args.encryption_key,
|
||||||
|
encrypt: args.encrypt,
|
||||||
};
|
};
|
||||||
|
|
||||||
// new server
|
// new server
|
||||||
|
@@ -3,6 +3,6 @@ pub struct DBOption {
|
|||||||
pub dir: String,
|
pub dir: String,
|
||||||
pub port: u16,
|
pub port: u16,
|
||||||
pub debug: bool,
|
pub debug: bool,
|
||||||
pub max_databases: Option<u64>, // None = unlimited, Some(n) = limit to n
|
pub encrypt: bool,
|
||||||
pub encryption_key: Option<String>, // Master encryption key
|
pub encryption_key: Option<String>, // Master encryption key
|
||||||
}
|
}
|
||||||
|
@@ -35,12 +35,6 @@ impl Server {
|
|||||||
return Ok(storage.clone());
|
return Ok(storage.clone());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check database limit if set
|
|
||||||
if let Some(max_db) = self.option.max_databases {
|
|
||||||
if self.selected_db >= max_db {
|
|
||||||
return Err(DBError(format!("DB index {} is out of range (max: {})", self.selected_db, max_db - 1)));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create new database file
|
// Create new database file
|
||||||
let db_file_path = std::path::PathBuf::from(self.option.dir.clone())
|
let db_file_path = std::path::PathBuf::from(self.option.dir.clone())
|
||||||
@@ -58,10 +52,8 @@ impl Server {
|
|||||||
Ok(storage)
|
Ok(storage)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn should_encrypt_db(&self, db_index: u64) -> bool {
|
fn should_encrypt_db(&self, _db_index: u64) -> bool {
|
||||||
// You can implement logic here to determine which databases should be encrypted
|
self.option.encrypt
|
||||||
// For now, let's say databases with even numbers are encrypted if key is provided
|
|
||||||
self.option.encryption_key.is_some() && db_index % 2 == 0
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn handle(
|
pub async fn handle(
|
||||||
|
127
src/storage.rs
127
src/storage.rs
@@ -104,17 +104,12 @@ fn glob_match(pattern: &str, text: &str) -> bool {
|
|||||||
// Table definitions for different Redis data types
|
// Table definitions for different Redis data types
|
||||||
const TYPES_TABLE: TableDefinition<&str, &str> = TableDefinition::new("types");
|
const TYPES_TABLE: TableDefinition<&str, &str> = TableDefinition::new("types");
|
||||||
const STRINGS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("strings");
|
const STRINGS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("strings");
|
||||||
const HASHES_TABLE: TableDefinition<(&str, &str), &str> = TableDefinition::new("hashes");
|
const HASHES_TABLE: TableDefinition<(&str, &str), &[u8]> = TableDefinition::new("hashes");
|
||||||
const LISTS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("lists");
|
const LISTS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("lists");
|
||||||
const STREAMS_META_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("streams_meta");
|
const STREAMS_META_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("streams_meta");
|
||||||
const STREAMS_DATA_TABLE: TableDefinition<(&str, &str), &[u8]> = TableDefinition::new("streams_data");
|
const STREAMS_DATA_TABLE: TableDefinition<(&str, &str), &[u8]> = TableDefinition::new("streams_data");
|
||||||
const ENCRYPTED_TABLE: TableDefinition<&str, u8> = TableDefinition::new("encrypted");
|
const ENCRYPTED_TABLE: TableDefinition<&str, u8> = TableDefinition::new("encrypted");
|
||||||
|
const EXPIRATION_TABLE: TableDefinition<&str, u64> = TableDefinition::new("expiration");
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
|
||||||
pub struct StringValue {
|
|
||||||
pub value: String,
|
|
||||||
pub expires_at_ms: Option<u128>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||||
pub struct StreamEntry {
|
pub struct StreamEntry {
|
||||||
@@ -153,6 +148,7 @@ impl Storage {
|
|||||||
let _ = write_txn.open_table(STREAMS_META_TABLE)?;
|
let _ = write_txn.open_table(STREAMS_META_TABLE)?;
|
||||||
let _ = write_txn.open_table(STREAMS_DATA_TABLE)?;
|
let _ = write_txn.open_table(STREAMS_DATA_TABLE)?;
|
||||||
let _ = write_txn.open_table(ENCRYPTED_TABLE)?;
|
let _ = write_txn.open_table(ENCRYPTED_TABLE)?;
|
||||||
|
let _ = write_txn.open_table(EXPIRATION_TABLE)?;
|
||||||
}
|
}
|
||||||
write_txn.commit()?;
|
write_txn.commit()?;
|
||||||
|
|
||||||
@@ -218,6 +214,7 @@ impl Storage {
|
|||||||
let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
|
let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
|
||||||
let mut streams_meta_table = write_txn.open_table(STREAMS_META_TABLE)?;
|
let mut streams_meta_table = write_txn.open_table(STREAMS_META_TABLE)?;
|
||||||
let mut streams_data_table = write_txn.open_table(STREAMS_DATA_TABLE)?;
|
let mut streams_data_table = write_txn.open_table(STREAMS_DATA_TABLE)?;
|
||||||
|
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
|
||||||
|
|
||||||
// inefficient, but there is no other way
|
// inefficient, but there is no other way
|
||||||
let keys: Vec<String> = types_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect();
|
let keys: Vec<String> = types_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect();
|
||||||
@@ -252,6 +249,10 @@ impl Storage {
|
|||||||
for (key, field) in keys {
|
for (key, field) in keys {
|
||||||
streams_data_table.remove((key.as_str(), field.as_str()))?;
|
streams_data_table.remove((key.as_str(), field.as_str()))?;
|
||||||
}
|
}
|
||||||
|
let keys: Vec<String> = expiration_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect();
|
||||||
|
for key in keys {
|
||||||
|
expiration_table.remove(key.as_str())?;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
write_txn.commit()?;
|
write_txn.commit()?;
|
||||||
Ok(())
|
Ok(())
|
||||||
@@ -274,22 +275,23 @@ impl Storage {
|
|||||||
let types_table = read_txn.open_table(TYPES_TABLE)?;
|
let types_table = read_txn.open_table(TYPES_TABLE)?;
|
||||||
match types_table.get(key)? {
|
match types_table.get(key)? {
|
||||||
Some(type_val) if type_val.value() == "string" => {
|
Some(type_val) if type_val.value() == "string" => {
|
||||||
let strings_table = read_txn.open_table(STRINGS_TABLE)?;
|
// Check expiration first (unencrypted)
|
||||||
match strings_table.get(key)? {
|
let expiration_table = read_txn.open_table(EXPIRATION_TABLE)?;
|
||||||
Some(data) => {
|
if let Some(expires_at) = expiration_table.get(key)? {
|
||||||
let decrypted = self.decrypt_if_needed(data.value())?;
|
if now_in_millis() > expires_at.value() as u128 {
|
||||||
let string_value: StringValue = bincode::deserialize(&decrypted)?;
|
|
||||||
|
|
||||||
// Check if expired
|
|
||||||
if let Some(expires_at) = string_value.expires_at_ms {
|
|
||||||
if now_in_millis() > expires_at {
|
|
||||||
drop(read_txn);
|
drop(read_txn);
|
||||||
self.del(key.to_string())?;
|
self.del(key.to_string())?;
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(Some(string_value.value))
|
// Get and decrypt value
|
||||||
|
let strings_table = read_txn.open_table(STRINGS_TABLE)?;
|
||||||
|
match strings_table.get(key)? {
|
||||||
|
Some(data) => {
|
||||||
|
let decrypted = self.decrypt_if_needed(data.value())?;
|
||||||
|
let value = String::from_utf8(decrypted)?;
|
||||||
|
Ok(Some(value))
|
||||||
}
|
}
|
||||||
None => Ok(None),
|
None => Ok(None),
|
||||||
}
|
}
|
||||||
@@ -310,13 +312,13 @@ impl Storage {
|
|||||||
types_table.insert(key.as_str(), "string")?;
|
types_table.insert(key.as_str(), "string")?;
|
||||||
|
|
||||||
let mut strings_table = write_txn.open_table(STRINGS_TABLE)?;
|
let mut strings_table = write_txn.open_table(STRINGS_TABLE)?;
|
||||||
let string_value = StringValue {
|
// Only encrypt the value, not expiration
|
||||||
value,
|
let encrypted = self.encrypt_if_needed(value.as_bytes())?;
|
||||||
expires_at_ms: None,
|
|
||||||
};
|
|
||||||
let serialized = bincode::serialize(&string_value)?;
|
|
||||||
let encrypted = self.encrypt_if_needed(&serialized)?;
|
|
||||||
strings_table.insert(key.as_str(), encrypted.as_slice())?;
|
strings_table.insert(key.as_str(), encrypted.as_slice())?;
|
||||||
|
|
||||||
|
// Remove any existing expiration since this is a regular SET
|
||||||
|
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
|
||||||
|
expiration_table.remove(key.as_str())?;
|
||||||
}
|
}
|
||||||
|
|
||||||
write_txn.commit()?;
|
write_txn.commit()?;
|
||||||
@@ -331,13 +333,14 @@ impl Storage {
|
|||||||
types_table.insert(key.as_str(), "string")?;
|
types_table.insert(key.as_str(), "string")?;
|
||||||
|
|
||||||
let mut strings_table = write_txn.open_table(STRINGS_TABLE)?;
|
let mut strings_table = write_txn.open_table(STRINGS_TABLE)?;
|
||||||
let string_value = StringValue {
|
// Only encrypt the value
|
||||||
value,
|
let encrypted = self.encrypt_if_needed(value.as_bytes())?;
|
||||||
expires_at_ms: Some(expire_ms + now_in_millis()),
|
|
||||||
};
|
|
||||||
let serialized = bincode::serialize(&string_value)?;
|
|
||||||
let encrypted = self.encrypt_if_needed(&serialized)?;
|
|
||||||
strings_table.insert(key.as_str(), encrypted.as_slice())?;
|
strings_table.insert(key.as_str(), encrypted.as_slice())?;
|
||||||
|
|
||||||
|
// Store expiration separately (unencrypted)
|
||||||
|
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
|
||||||
|
let expires_at = expire_ms + now_in_millis();
|
||||||
|
expiration_table.insert(key.as_str(), &(expires_at as u64))?;
|
||||||
}
|
}
|
||||||
|
|
||||||
write_txn.commit()?;
|
write_txn.commit()?;
|
||||||
@@ -377,6 +380,10 @@ impl Storage {
|
|||||||
|
|
||||||
// Remove from lists table
|
// Remove from lists table
|
||||||
lists_table.remove(key.as_str())?;
|
lists_table.remove(key.as_str())?;
|
||||||
|
|
||||||
|
// Also remove expiration
|
||||||
|
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
|
||||||
|
expiration_table.remove(key.as_str())?;
|
||||||
}
|
}
|
||||||
|
|
||||||
write_txn.commit()?;
|
write_txn.commit()?;
|
||||||
@@ -427,7 +434,11 @@ impl Storage {
|
|||||||
|
|
||||||
for (field, value) in pairs {
|
for (field, value) in pairs {
|
||||||
let existed = hashes_table.get((key, field.as_str()))?.is_some();
|
let existed = hashes_table.get((key, field.as_str()))?.is_some();
|
||||||
hashes_table.insert((key, field.as_str()), value.as_str())?;
|
|
||||||
|
// Encrypt the value before storing
|
||||||
|
let encrypted = self.encrypt_if_needed(value.as_bytes())?;
|
||||||
|
hashes_table.insert((key, field.as_str()), encrypted.as_slice())?;
|
||||||
|
|
||||||
if !existed {
|
if !existed {
|
||||||
new_fields += 1;
|
new_fields += 1;
|
||||||
}
|
}
|
||||||
@@ -447,7 +458,11 @@ impl Storage {
|
|||||||
Some(type_val) if type_val.value() == "hash" => {
|
Some(type_val) if type_val.value() == "hash" => {
|
||||||
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
|
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
|
||||||
match hashes_table.get((key, field))? {
|
match hashes_table.get((key, field))? {
|
||||||
Some(value) => Ok(Some(value.value().to_string())),
|
Some(data) => {
|
||||||
|
let decrypted = self.decrypt_if_needed(data.value())?;
|
||||||
|
let value = String::from_utf8(decrypted)?;
|
||||||
|
Ok(Some(value))
|
||||||
|
}
|
||||||
None => Ok(None),
|
None => Ok(None),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -472,7 +487,9 @@ impl Storage {
|
|||||||
let (hash_key, field) = entry.0.value();
|
let (hash_key, field) = entry.0.value();
|
||||||
let value = entry.1.value();
|
let value = entry.1.value();
|
||||||
if hash_key == key {
|
if hash_key == key {
|
||||||
result.push((field.to_string(), value.to_string()));
|
let decrypted = self.decrypt_if_needed(value)?;
|
||||||
|
let value_str = String::from_utf8(decrypted)?;
|
||||||
|
result.push((field.to_string(), value_str));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -563,7 +580,9 @@ impl Storage {
|
|||||||
let (hash_key, _) = entry.0.value();
|
let (hash_key, _) = entry.0.value();
|
||||||
let value = entry.1.value();
|
let value = entry.1.value();
|
||||||
if hash_key == key {
|
if hash_key == key {
|
||||||
result.push(value.to_string());
|
let decrypted = self.decrypt_if_needed(value)?;
|
||||||
|
let value_str = String::from_utf8(decrypted)?;
|
||||||
|
result.push(value_str);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -610,7 +629,11 @@ impl Storage {
|
|||||||
|
|
||||||
for field in fields {
|
for field in fields {
|
||||||
match hashes_table.get((key, field.as_str()))? {
|
match hashes_table.get((key, field.as_str()))? {
|
||||||
Some(value) => result.push(Some(value.value().to_string())),
|
Some(data) => {
|
||||||
|
let decrypted = self.decrypt_if_needed(data.value())?;
|
||||||
|
let value = String::from_utf8(decrypted)?;
|
||||||
|
result.push(Some(value));
|
||||||
|
}
|
||||||
None => result.push(None),
|
None => result.push(None),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -649,7 +672,8 @@ impl Storage {
|
|||||||
|
|
||||||
// Check if field already exists
|
// Check if field already exists
|
||||||
if hashes_table.get((key, field))?.is_none() {
|
if hashes_table.get((key, field))?.is_none() {
|
||||||
hashes_table.insert((key, field), value)?;
|
let encrypted_value = self.encrypt_if_needed(value.as_bytes())?;
|
||||||
|
hashes_table.insert((key, field), encrypted_value.as_slice())?;
|
||||||
result = true;
|
result = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -769,8 +793,10 @@ impl Storage {
|
|||||||
};
|
};
|
||||||
|
|
||||||
if matches {
|
if matches {
|
||||||
|
let decrypted = self.decrypt_if_needed(value)?;
|
||||||
|
let value_str = String::from_utf8(decrypted)?;
|
||||||
fields.push(field.to_string());
|
fields.push(field.to_string());
|
||||||
fields.push(value.to_string());
|
fields.push(value_str);
|
||||||
returned_fields += 1;
|
returned_fields += 1;
|
||||||
|
|
||||||
if returned_fields >= count {
|
if returned_fields >= count {
|
||||||
@@ -792,17 +818,14 @@ impl Storage {
|
|||||||
pub fn ttl(&self, key: &str) -> Result<i64, DBError> {
|
pub fn ttl(&self, key: &str) -> Result<i64, DBError> {
|
||||||
let read_txn = self.db.begin_read()?;
|
let read_txn = self.db.begin_read()?;
|
||||||
|
|
||||||
// Check if key exists
|
|
||||||
let types_table = read_txn.open_table(TYPES_TABLE)?;
|
let types_table = read_txn.open_table(TYPES_TABLE)?;
|
||||||
match types_table.get(key)? {
|
match types_table.get(key)? {
|
||||||
Some(type_val) if type_val.value() == "string" => {
|
Some(type_val) if type_val.value() == "string" => {
|
||||||
let strings_table = read_txn.open_table(STRINGS_TABLE)?;
|
let expiration_table = read_txn.open_table(EXPIRATION_TABLE)?;
|
||||||
match strings_table.get(key)? {
|
match expiration_table.get(key)? {
|
||||||
Some(data) => {
|
|
||||||
let string_value: StringValue = bincode::deserialize(data.value())?;
|
|
||||||
match string_value.expires_at_ms {
|
|
||||||
Some(expires_at) => {
|
Some(expires_at) => {
|
||||||
let now = now_in_millis();
|
let now = now_in_millis();
|
||||||
|
let expires_at = expires_at.value() as u128;
|
||||||
if now > expires_at {
|
if now > expires_at {
|
||||||
Ok(-2) // Key expired
|
Ok(-2) // Key expired
|
||||||
} else {
|
} else {
|
||||||
@@ -812,10 +835,7 @@ impl Storage {
|
|||||||
None => Ok(-1), // No expiration
|
None => Ok(-1), // No expiration
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
None => Ok(-2), // Key doesn't exist
|
Some(_) => Ok(-1), // Other types don't have TTL
|
||||||
}
|
|
||||||
}
|
|
||||||
Some(_) => Ok(-1), // Other types don't have TTL implemented yet
|
|
||||||
None => Ok(-2), // Key doesn't exist
|
None => Ok(-2), // Key doesn't exist
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -825,21 +845,16 @@ impl Storage {
|
|||||||
let types_table = read_txn.open_table(TYPES_TABLE)?;
|
let types_table = read_txn.open_table(TYPES_TABLE)?;
|
||||||
|
|
||||||
match types_table.get(key)? {
|
match types_table.get(key)? {
|
||||||
Some(_) => {
|
Some(type_val) => {
|
||||||
// For string types, check if not expired
|
// For string types, check if not expired
|
||||||
if let Some(type_val) = types_table.get(key)? {
|
|
||||||
if type_val.value() == "string" {
|
if type_val.value() == "string" {
|
||||||
let strings_table = read_txn.open_table(STRINGS_TABLE)?;
|
let expiration_table = read_txn.open_table(EXPIRATION_TABLE)?;
|
||||||
if let Some(data) = strings_table.get(key)? {
|
if let Some(expires_at) = expiration_table.get(key)? {
|
||||||
let string_value: StringValue = bincode::deserialize(data.value())?;
|
if now_in_millis() > expires_at.value() as u128 {
|
||||||
if let Some(expires_at) = string_value.expires_at_ms {
|
|
||||||
if now_in_millis() > expires_at {
|
|
||||||
return Ok(false); // Expired
|
return Ok(false); // Expired
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(true)
|
Ok(true)
|
||||||
}
|
}
|
||||||
None => Ok(false),
|
None => Ok(false),
|
||||||
@@ -878,7 +893,7 @@ impl Storage {
|
|||||||
None => ListValue { elements: Vec::new() },
|
None => ListValue { elements: Vec::new() },
|
||||||
};
|
};
|
||||||
|
|
||||||
for element in elements.into_iter().rev() {
|
for element in elements.into_iter() {
|
||||||
list_value.elements.insert(0, element);
|
list_value.elements.insert(0, element);
|
||||||
}
|
}
|
||||||
new_len = list_value.elements.len() as u64;
|
new_len = list_value.elements.len() as u64;
|
||||||
|
@@ -25,7 +25,7 @@ async fn debug_hset_simple() {
|
|||||||
dir: test_dir.to_string(),
|
dir: test_dir.to_string(),
|
||||||
port,
|
port,
|
||||||
debug: false,
|
debug: false,
|
||||||
max_databases: Some(16),
|
encrypt: false,
|
||||||
encryption_key: None,
|
encryption_key: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@@ -16,7 +16,7 @@ async fn debug_hset_return_value() {
|
|||||||
dir: test_dir.to_string(),
|
dir: test_dir.to_string(),
|
||||||
port: 16390,
|
port: 16390,
|
||||||
debug: false,
|
debug: false,
|
||||||
max_databases: Some(16),
|
encrypt: false,
|
||||||
encryption_key: None,
|
encryption_key: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@@ -120,9 +120,7 @@ async fn all_tests() {
|
|||||||
test_transaction_operations(&mut conn).await;
|
test_transaction_operations(&mut conn).await;
|
||||||
test_discard_transaction(&mut conn).await;
|
test_discard_transaction(&mut conn).await;
|
||||||
test_type_command(&mut conn).await;
|
test_type_command(&mut conn).await;
|
||||||
test_config_commands(&mut conn).await;
|
|
||||||
test_info_command(&mut conn).await;
|
test_info_command(&mut conn).await;
|
||||||
test_error_handling(&mut conn).await;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn test_basic_ping(conn: &mut Connection) {
|
async fn test_basic_ping(conn: &mut Connection) {
|
||||||
@@ -308,23 +306,6 @@ async fn test_type_command(conn: &mut Connection) {
|
|||||||
cleanup_keys(conn).await;
|
cleanup_keys(conn).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn test_config_commands(conn: &mut Connection) {
|
|
||||||
cleanup_keys(conn).await;
|
|
||||||
let result: Vec<String> = redis::cmd("CONFIG")
|
|
||||||
.arg("GET")
|
|
||||||
.arg("databases")
|
|
||||||
.query(conn)
|
|
||||||
.unwrap();
|
|
||||||
assert_eq!(result, vec!["databases", "16"]);
|
|
||||||
let result: Vec<String> = redis::cmd("CONFIG")
|
|
||||||
.arg("GET")
|
|
||||||
.arg("dir")
|
|
||||||
.query(conn)
|
|
||||||
.unwrap();
|
|
||||||
assert_eq!(result[0], "dir");
|
|
||||||
assert!(result[1].contains("/tmp/herodb_test_"));
|
|
||||||
cleanup_keys(conn).await;
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn test_info_command(conn: &mut Connection) {
|
async fn test_info_command(conn: &mut Connection) {
|
||||||
cleanup_keys(conn).await;
|
cleanup_keys(conn).await;
|
||||||
@@ -334,17 +315,3 @@ async fn test_info_command(conn: &mut Connection) {
|
|||||||
assert!(result.contains("role:master"));
|
assert!(result.contains("role:master"));
|
||||||
cleanup_keys(conn).await;
|
cleanup_keys(conn).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn test_error_handling(conn: &mut Connection) {
|
|
||||||
cleanup_keys(conn).await;
|
|
||||||
let _: () = conn.set("string", "value").unwrap();
|
|
||||||
let result: RedisResult<String> = conn.hget("string", "field");
|
|
||||||
assert!(result.is_err());
|
|
||||||
let result: RedisResult<String> = redis::cmd("UNKNOWN").query(conn);
|
|
||||||
assert!(result.is_err());
|
|
||||||
let result: RedisResult<Vec<String>> = redis::cmd("EXEC").query(conn);
|
|
||||||
assert!(result.is_err());
|
|
||||||
let result: RedisResult<()> = redis::cmd("DISCARD").query(conn);
|
|
||||||
assert!(result.is_err());
|
|
||||||
cleanup_keys(conn).await;
|
|
||||||
}
|
|
@@ -20,7 +20,7 @@ async fn start_test_server(test_name: &str) -> (Server, u16) {
|
|||||||
dir: test_dir,
|
dir: test_dir,
|
||||||
port,
|
port,
|
||||||
debug: true,
|
debug: true,
|
||||||
max_databases: Some(16),
|
encrypt: false,
|
||||||
encryption_key: None,
|
encryption_key: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -581,22 +581,19 @@ async fn test_list_operations() {
|
|||||||
|
|
||||||
// Test LRANGE
|
// Test LRANGE
|
||||||
let response = send_command(&mut stream, "*4\r\n$6\r\nLRANGE\r\n$4\r\nlist\r\n$1\r\n0\r\n$2\r\n-1\r\n").await;
|
let response = send_command(&mut stream, "*4\r\n$6\r\nLRANGE\r\n$4\r\nlist\r\n$1\r\n0\r\n$2\r\n-1\r\n").await;
|
||||||
assert!(response.contains("b"));
|
assert_eq!(response, "*4\r\n$1\r\nb\r\n$1\r\na\r\n$1\r\nc\r\n$1\r\nd\r\n");
|
||||||
assert!(response.contains("a"));
|
|
||||||
assert!(response.contains("c"));
|
|
||||||
assert!(response.contains("d"));
|
|
||||||
|
|
||||||
// Test LINDEX
|
// Test LINDEX
|
||||||
let response = send_command(&mut stream, "*3\r\n$6\r\nLINDEX\r\n$4\r\nlist\r\n$1\r\n0\r\n").await;
|
let response = send_command(&mut stream, "*3\r\n$6\r\nLINDEX\r\n$4\r\nlist\r\n$1\r\n0\r\n").await;
|
||||||
assert!(response.contains("b"));
|
assert_eq!(response, "$1\r\nb\r\n");
|
||||||
|
|
||||||
// Test LPOP
|
// Test LPOP
|
||||||
let response = send_command(&mut stream, "*2\r\n$4\r\nLPOP\r\n$4\r\nlist\r\n").await;
|
let response = send_command(&mut stream, "*2\r\n$4\r\nLPOP\r\n$4\r\nlist\r\n").await;
|
||||||
assert!(response.contains("b"));
|
assert_eq!(response, "$1\r\nb\r\n");
|
||||||
|
|
||||||
// Test RPOP
|
// Test RPOP
|
||||||
let response = send_command(&mut stream, "*2\r\n$4\r\nRPOP\r\n$4\r\nlist\r\n").await;
|
let response = send_command(&mut stream, "*2\r\n$4\r\nRPOP\r\n$4\r\nlist\r\n").await;
|
||||||
assert!(response.contains("d"));
|
assert_eq!(response, "$1\r\nd\r\n");
|
||||||
|
|
||||||
// Test LREM
|
// Test LREM
|
||||||
send_command(&mut stream, "*3\r\n$5\r\nLPUSH\r\n$4\r\nlist\r\n$1\r\na\r\n").await; // list is now a, c, a
|
send_command(&mut stream, "*3\r\n$5\r\nLPUSH\r\n$4\r\nlist\r\n$1\r\na\r\n").await; // list is now a, c, a
|
||||||
|
@@ -22,7 +22,7 @@ async fn start_test_server(test_name: &str) -> (Server, u16) {
|
|||||||
dir: test_dir,
|
dir: test_dir,
|
||||||
port,
|
port,
|
||||||
debug: true,
|
debug: true,
|
||||||
max_databases: Some(16),
|
encrypt: false,
|
||||||
encryption_key: None,
|
encryption_key: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@@ -20,7 +20,8 @@ async fn start_test_server(test_name: &str) -> (Server, u16) {
|
|||||||
dir: test_dir,
|
dir: test_dir,
|
||||||
port,
|
port,
|
||||||
debug: false,
|
debug: false,
|
||||||
databases: 16,
|
encrypt: false,
|
||||||
|
encryption_key: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let server = Server::new(option).await;
|
let server = Server::new(option).await;
|
||||||
|
Reference in New Issue
Block a user