...
This commit is contained in:
100
src/cmd.rs
100
src/cmd.rs
@@ -28,7 +28,11 @@ pub enum Cmd {
|
||||
HLen(String),
|
||||
HMGet(String, Vec<String>),
|
||||
HSetNx(String, String, String),
|
||||
HScan(String, u64, Option<String>, Option<u64>), // key, cursor, pattern, count
|
||||
Scan(u64, Option<String>, Option<u64>), // cursor, pattern, count
|
||||
Ttl(String),
|
||||
Exists(String),
|
||||
Quit,
|
||||
Unknow,
|
||||
}
|
||||
|
||||
@@ -117,7 +121,7 @@ impl Cmd {
|
||||
}
|
||||
let mut pairs = Vec::new();
|
||||
let mut i = 2;
|
||||
while i < cmd.len() - 1 {
|
||||
while i + 1 < cmd.len() {
|
||||
pairs.push((cmd[i].clone(), cmd[i + 1].clone()));
|
||||
i += 2;
|
||||
}
|
||||
@@ -177,6 +181,44 @@ impl Cmd {
|
||||
}
|
||||
Cmd::HSetNx(cmd[1].clone(), cmd[2].clone(), cmd[3].clone())
|
||||
}
|
||||
"hscan" => {
|
||||
if cmd.len() < 3 {
|
||||
return Err(DBError(format!("wrong number of arguments for HSCAN command")));
|
||||
}
|
||||
|
||||
let key = cmd[1].clone();
|
||||
let cursor = cmd[2].parse::<u64>().map_err(|_|
|
||||
DBError("ERR invalid cursor".to_string()))?;
|
||||
|
||||
let mut pattern = None;
|
||||
let mut count = None;
|
||||
let mut i = 3;
|
||||
|
||||
while i < cmd.len() {
|
||||
match cmd[i].to_lowercase().as_str() {
|
||||
"match" => {
|
||||
if i + 1 >= cmd.len() {
|
||||
return Err(DBError("ERR syntax error".to_string()));
|
||||
}
|
||||
pattern = Some(cmd[i + 1].clone());
|
||||
i += 2;
|
||||
}
|
||||
"count" => {
|
||||
if i + 1 >= cmd.len() {
|
||||
return Err(DBError("ERR syntax error".to_string()));
|
||||
}
|
||||
count = Some(cmd[i + 1].parse::<u64>().map_err(|_|
|
||||
DBError("ERR value is not an integer or out of range".to_string()))?);
|
||||
i += 2;
|
||||
}
|
||||
_ => {
|
||||
return Err(DBError(format!("ERR syntax error")));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Cmd::HScan(key, cursor, pattern, count)
|
||||
}
|
||||
"scan" => {
|
||||
if cmd.len() < 2 {
|
||||
return Err(DBError(format!("wrong number of arguments for SCAN command")));
|
||||
@@ -214,6 +256,24 @@ impl Cmd {
|
||||
|
||||
Cmd::Scan(cursor, pattern, count)
|
||||
}
|
||||
"ttl" => {
|
||||
if cmd.len() != 2 {
|
||||
return Err(DBError(format!("wrong number of arguments for TTL command")));
|
||||
}
|
||||
Cmd::Ttl(cmd[1].clone())
|
||||
}
|
||||
"exists" => {
|
||||
if cmd.len() != 2 {
|
||||
return Err(DBError(format!("wrong number of arguments for EXISTS command")));
|
||||
}
|
||||
Cmd::Exists(cmd[1].clone())
|
||||
}
|
||||
"quit" => {
|
||||
if cmd.len() != 1 {
|
||||
return Err(DBError(format!("wrong number of arguments for QUIT command")));
|
||||
}
|
||||
Cmd::Quit
|
||||
}
|
||||
_ => Cmd::Unknow,
|
||||
},
|
||||
protocol.0,
|
||||
@@ -282,7 +342,11 @@ impl Cmd {
|
||||
Cmd::HLen(key) => hlen_cmd(server, key).await,
|
||||
Cmd::HMGet(key, fields) => hmget_cmd(server, key, fields).await,
|
||||
Cmd::HSetNx(key, field, value) => hsetnx_cmd(server, key, field, value).await,
|
||||
Cmd::HScan(key, cursor, pattern, count) => hscan_cmd(server, key, cursor, pattern.as_deref(), count).await,
|
||||
Cmd::Scan(cursor, pattern, count) => scan_cmd(server, cursor, pattern.as_deref(), count).await,
|
||||
Cmd::Ttl(key) => ttl_cmd(server, key).await,
|
||||
Cmd::Exists(key) => exists_cmd(server, key).await,
|
||||
Cmd::Quit => Ok(Protocol::SimpleString("OK".to_string())),
|
||||
Cmd::Unknow => Ok(Protocol::err("unknown cmd")),
|
||||
}
|
||||
}
|
||||
@@ -332,7 +396,11 @@ fn config_get_cmd(name: &String, server: &Server) -> Result<Protocol, DBError> {
|
||||
Protocol::BulkString(name.clone()),
|
||||
Protocol::BulkString("herodb.redb".to_string()),
|
||||
])),
|
||||
_ => Err(DBError(format!("unsupported config {:?}", name))),
|
||||
"databases" => Ok(Protocol::Array(vec![
|
||||
Protocol::BulkString(name.clone()),
|
||||
Protocol::BulkString("16".to_string()),
|
||||
])),
|
||||
_ => Ok(Protocol::Array(vec![])), // Return empty array for unknown configs instead of error
|
||||
}
|
||||
}
|
||||
|
||||
@@ -497,3 +565,31 @@ async fn scan_cmd(server: &Server, cursor: &u64, pattern: Option<&str>, count: &
|
||||
Err(e) => Ok(Protocol::err(&e.0)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn hscan_cmd(server: &Server, key: &str, cursor: &u64, pattern: Option<&str>, count: &Option<u64>) -> Result<Protocol, DBError> {
|
||||
match server.storage.hscan(key, *cursor, pattern, *count) {
|
||||
Ok((next_cursor, fields)) => {
|
||||
let mut result = Vec::new();
|
||||
result.push(Protocol::BulkString(next_cursor.to_string()));
|
||||
result.push(Protocol::Array(
|
||||
fields.into_iter().map(Protocol::BulkString).collect(),
|
||||
));
|
||||
Ok(Protocol::Array(result))
|
||||
}
|
||||
Err(e) => Ok(Protocol::err(&e.0)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn ttl_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
|
||||
match server.storage.ttl(key) {
|
||||
Ok(ttl) => Ok(Protocol::SimpleString(ttl.to_string())),
|
||||
Err(e) => Ok(Protocol::err(&e.0)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn exists_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
|
||||
match server.storage.exists(key) {
|
||||
Ok(exists) => Ok(Protocol::SimpleString(if exists { "1" } else { "0" }.to_string())),
|
||||
Err(e) => Ok(Protocol::err(&e.0)),
|
||||
}
|
||||
}
|
||||
|
@@ -1,6 +1,6 @@
|
||||
mod cmd;
|
||||
pub mod cmd;
|
||||
pub mod error;
|
||||
pub mod options;
|
||||
mod protocol;
|
||||
pub mod protocol;
|
||||
pub mod server;
|
||||
mod storage;
|
||||
pub mod storage;
|
||||
|
@@ -50,6 +50,9 @@ impl Server {
|
||||
Cmd::from(s).unwrap_or((Cmd::Unknow, Protocol::err("unknow cmd")));
|
||||
println!("got command: {:?}, protocol: {:?}", cmd, protocol);
|
||||
|
||||
// Check if this is a QUIT command before processing
|
||||
let is_quit = matches!(cmd, Cmd::Quit);
|
||||
|
||||
let res = cmd
|
||||
.run(self, protocol, &mut queued_cmd)
|
||||
.await
|
||||
@@ -58,6 +61,12 @@ impl Server {
|
||||
|
||||
println!("going to send response {}", res.encode());
|
||||
_ = stream.write(res.encode().as_bytes()).await?;
|
||||
|
||||
// If this was a QUIT command, close the connection
|
||||
if is_quit {
|
||||
println!("[handle] QUIT command received, closing connection");
|
||||
return Ok(());
|
||||
}
|
||||
} else {
|
||||
println!("[handle] going to break");
|
||||
break;
|
||||
|
129
src/storage.rs
129
src/storage.rs
@@ -506,4 +506,133 @@ impl Storage {
|
||||
|
||||
Ok((next_cursor, keys))
|
||||
}
|
||||
|
||||
pub fn hscan(&self, key: &str, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<String>), DBError> {
|
||||
let read_txn = self.db.begin_read()?;
|
||||
|
||||
// Check if key exists and is a hash
|
||||
let types_table = read_txn.open_table(TYPES_TABLE)?;
|
||||
match types_table.get(key)? {
|
||||
Some(type_val) if type_val.value() == "hash" => {
|
||||
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
|
||||
let count = count.unwrap_or(10);
|
||||
let mut fields = Vec::new();
|
||||
let mut current_cursor = 0u64;
|
||||
let mut returned_fields = 0u64;
|
||||
|
||||
let mut iter = hashes_table.iter()?;
|
||||
while let Some(entry) = iter.next() {
|
||||
let entry = entry?;
|
||||
let (hash_key, field) = entry.0.value();
|
||||
let value = entry.1.value();
|
||||
|
||||
if hash_key != key {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Skip fields until we reach the cursor position
|
||||
if current_cursor < cursor {
|
||||
current_cursor += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check if field matches pattern
|
||||
let matches = match pattern {
|
||||
Some(pat) => {
|
||||
if pat == "*" {
|
||||
true
|
||||
} else if pat.contains('*') {
|
||||
let pattern_parts: Vec<&str> = pat.split('*').collect();
|
||||
if pattern_parts.len() == 2 {
|
||||
let prefix = pattern_parts[0];
|
||||
let suffix = pattern_parts[1];
|
||||
field.starts_with(prefix) && field.ends_with(suffix)
|
||||
} else {
|
||||
field.contains(&pat.replace('*', ""))
|
||||
}
|
||||
} else {
|
||||
field.contains(pat)
|
||||
}
|
||||
}
|
||||
None => true,
|
||||
};
|
||||
|
||||
if matches {
|
||||
fields.push(field.to_string());
|
||||
fields.push(value.to_string());
|
||||
returned_fields += 1;
|
||||
|
||||
if returned_fields >= count {
|
||||
current_cursor += 1;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
current_cursor += 1;
|
||||
}
|
||||
|
||||
let next_cursor = if returned_fields < count { 0 } else { current_cursor };
|
||||
Ok((next_cursor, fields))
|
||||
}
|
||||
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
|
||||
None => Ok((0, Vec::new())),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn ttl(&self, key: &str) -> Result<i64, DBError> {
|
||||
let read_txn = self.db.begin_read()?;
|
||||
|
||||
// Check if key exists
|
||||
let types_table = read_txn.open_table(TYPES_TABLE)?;
|
||||
match types_table.get(key)? {
|
||||
Some(type_val) if type_val.value() == "string" => {
|
||||
let strings_table = read_txn.open_table(STRINGS_TABLE)?;
|
||||
match strings_table.get(key)? {
|
||||
Some(data) => {
|
||||
let string_value: StringValue = bincode::deserialize(data.value())?;
|
||||
match string_value.expires_at_ms {
|
||||
Some(expires_at) => {
|
||||
let now = now_in_millis();
|
||||
if now > expires_at {
|
||||
Ok(-2) // Key expired
|
||||
} else {
|
||||
Ok(((expires_at - now) / 1000) as i64) // TTL in seconds
|
||||
}
|
||||
}
|
||||
None => Ok(-1), // No expiration
|
||||
}
|
||||
}
|
||||
None => Ok(-2), // Key doesn't exist
|
||||
}
|
||||
}
|
||||
Some(_) => Ok(-1), // Other types don't have TTL implemented yet
|
||||
None => Ok(-2), // Key doesn't exist
|
||||
}
|
||||
}
|
||||
|
||||
pub fn exists(&self, key: &str) -> Result<bool, DBError> {
|
||||
let read_txn = self.db.begin_read()?;
|
||||
let types_table = read_txn.open_table(TYPES_TABLE)?;
|
||||
|
||||
match types_table.get(key)? {
|
||||
Some(_) => {
|
||||
// For string types, check if not expired
|
||||
if let Some(type_val) = types_table.get(key)? {
|
||||
if type_val.value() == "string" {
|
||||
let strings_table = read_txn.open_table(STRINGS_TABLE)?;
|
||||
if let Some(data) = strings_table.get(key)? {
|
||||
let string_value: StringValue = bincode::deserialize(data.value())?;
|
||||
if let Some(expires_at) = string_value.expires_at_ms {
|
||||
if now_in_millis() > expires_at {
|
||||
return Ok(false); // Expired
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(true)
|
||||
}
|
||||
None => Ok(false),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user