This commit is contained in:
2025-08-16 07:54:55 +02:00
parent d3e28cafe4
commit 0f6e595000
17 changed files with 2351 additions and 71 deletions

View File

@@ -28,7 +28,11 @@ pub enum Cmd {
HLen(String),
HMGet(String, Vec<String>),
HSetNx(String, String, String),
HScan(String, u64, Option<String>, Option<u64>), // key, cursor, pattern, count
Scan(u64, Option<String>, Option<u64>), // cursor, pattern, count
Ttl(String),
Exists(String),
Quit,
Unknow,
}
@@ -117,7 +121,7 @@ impl Cmd {
}
let mut pairs = Vec::new();
let mut i = 2;
while i < cmd.len() - 1 {
while i + 1 < cmd.len() {
pairs.push((cmd[i].clone(), cmd[i + 1].clone()));
i += 2;
}
@@ -177,6 +181,44 @@ impl Cmd {
}
Cmd::HSetNx(cmd[1].clone(), cmd[2].clone(), cmd[3].clone())
}
"hscan" => {
if cmd.len() < 3 {
return Err(DBError(format!("wrong number of arguments for HSCAN command")));
}
let key = cmd[1].clone();
let cursor = cmd[2].parse::<u64>().map_err(|_|
DBError("ERR invalid cursor".to_string()))?;
let mut pattern = None;
let mut count = None;
let mut i = 3;
while i < cmd.len() {
match cmd[i].to_lowercase().as_str() {
"match" => {
if i + 1 >= cmd.len() {
return Err(DBError("ERR syntax error".to_string()));
}
pattern = Some(cmd[i + 1].clone());
i += 2;
}
"count" => {
if i + 1 >= cmd.len() {
return Err(DBError("ERR syntax error".to_string()));
}
count = Some(cmd[i + 1].parse::<u64>().map_err(|_|
DBError("ERR value is not an integer or out of range".to_string()))?);
i += 2;
}
_ => {
return Err(DBError(format!("ERR syntax error")));
}
}
}
Cmd::HScan(key, cursor, pattern, count)
}
"scan" => {
if cmd.len() < 2 {
return Err(DBError(format!("wrong number of arguments for SCAN command")));
@@ -214,6 +256,24 @@ impl Cmd {
Cmd::Scan(cursor, pattern, count)
}
"ttl" => {
if cmd.len() != 2 {
return Err(DBError(format!("wrong number of arguments for TTL command")));
}
Cmd::Ttl(cmd[1].clone())
}
"exists" => {
if cmd.len() != 2 {
return Err(DBError(format!("wrong number of arguments for EXISTS command")));
}
Cmd::Exists(cmd[1].clone())
}
"quit" => {
if cmd.len() != 1 {
return Err(DBError(format!("wrong number of arguments for QUIT command")));
}
Cmd::Quit
}
_ => Cmd::Unknow,
},
protocol.0,
@@ -282,7 +342,11 @@ impl Cmd {
Cmd::HLen(key) => hlen_cmd(server, key).await,
Cmd::HMGet(key, fields) => hmget_cmd(server, key, fields).await,
Cmd::HSetNx(key, field, value) => hsetnx_cmd(server, key, field, value).await,
Cmd::HScan(key, cursor, pattern, count) => hscan_cmd(server, key, cursor, pattern.as_deref(), count).await,
Cmd::Scan(cursor, pattern, count) => scan_cmd(server, cursor, pattern.as_deref(), count).await,
Cmd::Ttl(key) => ttl_cmd(server, key).await,
Cmd::Exists(key) => exists_cmd(server, key).await,
Cmd::Quit => Ok(Protocol::SimpleString("OK".to_string())),
Cmd::Unknow => Ok(Protocol::err("unknown cmd")),
}
}
@@ -332,7 +396,11 @@ fn config_get_cmd(name: &String, server: &Server) -> Result<Protocol, DBError> {
Protocol::BulkString(name.clone()),
Protocol::BulkString("herodb.redb".to_string()),
])),
_ => Err(DBError(format!("unsupported config {:?}", name))),
"databases" => Ok(Protocol::Array(vec![
Protocol::BulkString(name.clone()),
Protocol::BulkString("16".to_string()),
])),
_ => Ok(Protocol::Array(vec![])), // Return empty array for unknown configs instead of error
}
}
@@ -497,3 +565,31 @@ async fn scan_cmd(server: &Server, cursor: &u64, pattern: Option<&str>, count: &
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn hscan_cmd(server: &Server, key: &str, cursor: &u64, pattern: Option<&str>, count: &Option<u64>) -> Result<Protocol, DBError> {
match server.storage.hscan(key, *cursor, pattern, *count) {
Ok((next_cursor, fields)) => {
let mut result = Vec::new();
result.push(Protocol::BulkString(next_cursor.to_string()));
result.push(Protocol::Array(
fields.into_iter().map(Protocol::BulkString).collect(),
));
Ok(Protocol::Array(result))
}
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn ttl_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
match server.storage.ttl(key) {
Ok(ttl) => Ok(Protocol::SimpleString(ttl.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn exists_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
match server.storage.exists(key) {
Ok(exists) => Ok(Protocol::SimpleString(if exists { "1" } else { "0" }.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}

View File

@@ -1,6 +1,6 @@
mod cmd;
pub mod cmd;
pub mod error;
pub mod options;
mod protocol;
pub mod protocol;
pub mod server;
mod storage;
pub mod storage;

View File

@@ -50,6 +50,9 @@ impl Server {
Cmd::from(s).unwrap_or((Cmd::Unknow, Protocol::err("unknow cmd")));
println!("got command: {:?}, protocol: {:?}", cmd, protocol);
// Check if this is a QUIT command before processing
let is_quit = matches!(cmd, Cmd::Quit);
let res = cmd
.run(self, protocol, &mut queued_cmd)
.await
@@ -58,6 +61,12 @@ impl Server {
println!("going to send response {}", res.encode());
_ = stream.write(res.encode().as_bytes()).await?;
// If this was a QUIT command, close the connection
if is_quit {
println!("[handle] QUIT command received, closing connection");
return Ok(());
}
} else {
println!("[handle] going to break");
break;

View File

@@ -506,4 +506,133 @@ impl Storage {
Ok((next_cursor, keys))
}
pub fn hscan(&self, key: &str, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<String>), DBError> {
let read_txn = self.db.begin_read()?;
// Check if key exists and is a hash
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "hash" => {
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
let count = count.unwrap_or(10);
let mut fields = Vec::new();
let mut current_cursor = 0u64;
let mut returned_fields = 0u64;
let mut iter = hashes_table.iter()?;
while let Some(entry) = iter.next() {
let entry = entry?;
let (hash_key, field) = entry.0.value();
let value = entry.1.value();
if hash_key != key {
continue;
}
// Skip fields until we reach the cursor position
if current_cursor < cursor {
current_cursor += 1;
continue;
}
// Check if field matches pattern
let matches = match pattern {
Some(pat) => {
if pat == "*" {
true
} else if pat.contains('*') {
let pattern_parts: Vec<&str> = pat.split('*').collect();
if pattern_parts.len() == 2 {
let prefix = pattern_parts[0];
let suffix = pattern_parts[1];
field.starts_with(prefix) && field.ends_with(suffix)
} else {
field.contains(&pat.replace('*', ""))
}
} else {
field.contains(pat)
}
}
None => true,
};
if matches {
fields.push(field.to_string());
fields.push(value.to_string());
returned_fields += 1;
if returned_fields >= count {
current_cursor += 1;
break;
}
}
current_cursor += 1;
}
let next_cursor = if returned_fields < count { 0 } else { current_cursor };
Ok((next_cursor, fields))
}
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
None => Ok((0, Vec::new())),
}
}
pub fn ttl(&self, key: &str) -> Result<i64, DBError> {
let read_txn = self.db.begin_read()?;
// Check if key exists
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "string" => {
let strings_table = read_txn.open_table(STRINGS_TABLE)?;
match strings_table.get(key)? {
Some(data) => {
let string_value: StringValue = bincode::deserialize(data.value())?;
match string_value.expires_at_ms {
Some(expires_at) => {
let now = now_in_millis();
if now > expires_at {
Ok(-2) // Key expired
} else {
Ok(((expires_at - now) / 1000) as i64) // TTL in seconds
}
}
None => Ok(-1), // No expiration
}
}
None => Ok(-2), // Key doesn't exist
}
}
Some(_) => Ok(-1), // Other types don't have TTL implemented yet
None => Ok(-2), // Key doesn't exist
}
}
pub fn exists(&self, key: &str) -> Result<bool, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(_) => {
// For string types, check if not expired
if let Some(type_val) = types_table.get(key)? {
if type_val.value() == "string" {
let strings_table = read_txn.open_table(STRINGS_TABLE)?;
if let Some(data) = strings_table.get(key)? {
let string_value: StringValue = bincode::deserialize(data.value())?;
if let Some(expires_at) = string_value.expires_at_ms {
if now_in_millis() > expires_at {
return Ok(false); // Expired
}
}
}
}
}
Ok(true)
}
None => Ok(false),
}
}
}