This commit is contained in:
2025-08-16 09:06:33 +02:00
parent 0511dddd99
commit 5502ff4bc5
10 changed files with 151 additions and 225 deletions

View File

@@ -4,6 +4,7 @@ use crate::{error::DBError, protocol::Protocol, server::Server};
pub enum Cmd {
Ping,
Echo(String),
Select(u16),
Get(String),
Set(String, String),
SetPx(String, String, u128),
@@ -60,6 +61,13 @@ impl Cmd {
}
Ok((
match cmd[0].to_lowercase().as_str() {
"select" => {
if cmd.len() != 2 {
return Err(DBError("wrong number of arguments for SELECT".to_string()));
}
let idx = cmd[1].parse::<u16>().map_err(|_| DBError("ERR DB index is not an integer".to_string()))?;
Cmd::Select(idx)
}
"echo" => Cmd::Echo(cmd[1].clone()),
"ping" => Cmd::Ping,
"get" => Cmd::Get(cmd[1].clone()),
@@ -419,6 +427,7 @@ impl Cmd {
}
match self {
Cmd::Select(db) => select_cmd(server, *db).await,
Cmd::Ping => Ok(Protocol::SimpleString("PONG".to_string())),
Cmd::Echo(s) => Ok(Protocol::SimpleString(s.clone())),
Cmd::Get(k) => get_cmd(server, k).await,
@@ -480,9 +489,17 @@ impl Cmd {
}
}
}
async fn select_cmd(server: &mut Server, db: u16) -> Result<Protocol, DBError> {
let idx = db as usize;
if idx >= server.storages.len() {
return Ok(Protocol::err("ERR DB index is out of range"));
}
server.selected_db = idx;
Ok(Protocol::SimpleString("OK".to_string()))
}
async fn lindex_cmd(server: &Server, key: &str, index: i64) -> Result<Protocol, DBError> {
match server.storage.lindex(key, index) {
match server.current_storage().lindex(key, index) {
Ok(Some(element)) => Ok(Protocol::BulkString(element)),
Ok(None) => Ok(Protocol::Null),
Err(e) => Ok(Protocol::err(&e.0)),
@@ -490,35 +507,35 @@ async fn lindex_cmd(server: &Server, key: &str, index: i64) -> Result<Protocol,
}
async fn lrange_cmd(server: &Server, key: &str, start: i64, stop: i64) -> Result<Protocol, DBError> {
match server.storage.lrange(key, start, stop) {
match server.current_storage().lrange(key, start, stop) {
Ok(elements) => Ok(Protocol::Array(elements.into_iter().map(Protocol::BulkString).collect())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn ltrim_cmd(server: &Server, key: &str, start: i64, stop: i64) -> Result<Protocol, DBError> {
match server.storage.ltrim(key, start, stop) {
match server.current_storage().ltrim(key, start, stop) {
Ok(_) => Ok(Protocol::SimpleString("OK".to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn lrem_cmd(server: &Server, key: &str, count: i64, element: &str) -> Result<Protocol, DBError> {
match server.storage.lrem(key, count, element) {
match server.current_storage().lrem(key, count, element) {
Ok(removed_count) => Ok(Protocol::SimpleString(removed_count.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn llen_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
match server.storage.llen(key) {
match server.current_storage().llen(key) {
Ok(len) => Ok(Protocol::SimpleString(len.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn lpop_cmd(server: &Server, key: &str, count: &Option<u64>) -> Result<Protocol, DBError> {
match server.storage.lpop(key, *count) {
match server.current_storage().lpop(key, *count) {
Ok(Some(elements)) => {
if count.is_some() {
Ok(Protocol::Array(elements.into_iter().map(Protocol::BulkString).collect()))
@@ -538,7 +555,7 @@ async fn lpop_cmd(server: &Server, key: &str, count: &Option<u64>) -> Result<Pro
}
async fn rpop_cmd(server: &Server, key: &str, count: &Option<u64>) -> Result<Protocol, DBError> {
match server.storage.rpop(key, *count) {
match server.current_storage().rpop(key, *count) {
Ok(Some(elements)) => {
if count.is_some() {
Ok(Protocol::Array(elements.into_iter().map(Protocol::BulkString).collect()))
@@ -558,14 +575,14 @@ async fn rpop_cmd(server: &Server, key: &str, count: &Option<u64>) -> Result<Pro
}
async fn lpush_cmd(server: &Server, key: &str, elements: &[String]) -> Result<Protocol, DBError> {
match server.storage.lpush(key, elements.to_vec()) {
match server.current_storage().lpush(key, elements.to_vec()) {
Ok(len) => Ok(Protocol::SimpleString(len.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn rpush_cmd(server: &Server, key: &str, elements: &[String]) -> Result<Protocol, DBError> {
match server.storage.rpush(key, elements.to_vec()) {
match server.current_storage().rpush(key, elements.to_vec()) {
Ok(len) => Ok(Protocol::SimpleString(len.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
@@ -589,7 +606,7 @@ async fn exec_cmd(
}
async fn incr_cmd(server: &Server, key: &String) -> Result<Protocol, DBError> {
let current_value = server.storage.get(key)?;
let current_value = server.current_storage().get(key)?;
let new_value = match current_value {
Some(v) => {
@@ -601,7 +618,7 @@ async fn incr_cmd(server: &Server, key: &String) -> Result<Protocol, DBError> {
None => 1,
};
server.storage.set(key.clone(), new_value.to_string())?;
server.current_storage().set(key.clone(), new_value.to_string())?;
Ok(Protocol::SimpleString(new_value.to_string()))
}
@@ -613,18 +630,18 @@ fn config_get_cmd(name: &String, server: &Server) -> Result<Protocol, DBError> {
])),
"dbfilename" => Ok(Protocol::Array(vec![
Protocol::BulkString(name.clone()),
Protocol::BulkString("herodb.redb".to_string()),
Protocol::BulkString(format!("{}.db", server.selected_db)),
])),
"databases" => Ok(Protocol::Array(vec![
Protocol::BulkString(name.clone()),
Protocol::BulkString("16".to_string()),
Protocol::BulkString(server.option.databases.to_string()),
])),
_ => Ok(Protocol::Array(vec![])), // Return empty array for unknown configs instead of error
_ => Ok(Protocol::Array(vec![])),
}
}
async fn keys_cmd(server: &Server) -> Result<Protocol, DBError> {
let keys = server.storage.keys("*")?;
let keys = server.current_storage().keys("*")?;
Ok(Protocol::Array(
keys.into_iter().map(Protocol::BulkString).collect(),
))
@@ -643,14 +660,14 @@ fn info_cmd(section: &Option<String>) -> Result<Protocol, DBError> {
}
async fn type_cmd(server: &Server, k: &String) -> Result<Protocol, DBError> {
match server.storage.get_key_type(k)? {
match server.current_storage().get_key_type(k)? {
Some(type_str) => Ok(Protocol::SimpleString(type_str)),
None => Ok(Protocol::SimpleString("none".to_string())),
}
}
async fn del_cmd(server: &Server, k: &str) -> Result<Protocol, DBError> {
server.storage.del(k.to_string())?;
server.current_storage().del(k.to_string())?;
Ok(Protocol::SimpleString("1".to_string()))
}
@@ -660,7 +677,7 @@ async fn set_ex_cmd(
v: &str,
x: &u128,
) -> Result<Protocol, DBError> {
server.storage.setx(k.to_string(), v.to_string(), *x * 1000)?;
server.current_storage().setx(k.to_string(), v.to_string(), *x * 1000)?;
Ok(Protocol::SimpleString("OK".to_string()))
}
@@ -670,28 +687,28 @@ async fn set_px_cmd(
v: &str,
x: &u128,
) -> Result<Protocol, DBError> {
server.storage.setx(k.to_string(), v.to_string(), *x)?;
server.current_storage().setx(k.to_string(), v.to_string(), *x)?;
Ok(Protocol::SimpleString("OK".to_string()))
}
async fn set_cmd(server: &Server, k: &str, v: &str) -> Result<Protocol, DBError> {
server.storage.set(k.to_string(), v.to_string())?;
server.current_storage().set(k.to_string(), v.to_string())?;
Ok(Protocol::SimpleString("OK".to_string()))
}
async fn get_cmd(server: &Server, k: &str) -> Result<Protocol, DBError> {
let v = server.storage.get(k)?;
let v = server.current_storage().get(k)?;
Ok(v.map_or(Protocol::Null, Protocol::BulkString))
}
// Hash command implementations
async fn hset_cmd(server: &Server, key: &str, pairs: &[(String, String)]) -> Result<Protocol, DBError> {
let new_fields = server.storage.hset(key, pairs)?;
let new_fields = server.current_storage().hset(key, pairs)?;
Ok(Protocol::SimpleString(new_fields.to_string()))
}
async fn hget_cmd(server: &Server, key: &str, field: &str) -> Result<Protocol, DBError> {
match server.storage.hget(key, field) {
match server.current_storage().hget(key, field) {
Ok(Some(value)) => Ok(Protocol::BulkString(value)),
Ok(None) => Ok(Protocol::Null),
Err(e) => Ok(Protocol::err(&e.0)),
@@ -699,7 +716,7 @@ async fn hget_cmd(server: &Server, key: &str, field: &str) -> Result<Protocol, D
}
async fn hgetall_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
match server.storage.hgetall(key) {
match server.current_storage().hgetall(key) {
Ok(pairs) => {
let mut result = Vec::new();
for (field, value) in pairs {
@@ -713,21 +730,21 @@ async fn hgetall_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
}
async fn hdel_cmd(server: &Server, key: &str, fields: &[String]) -> Result<Protocol, DBError> {
match server.storage.hdel(key, fields) {
match server.current_storage().hdel(key, fields) {
Ok(deleted) => Ok(Protocol::SimpleString(deleted.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn hexists_cmd(server: &Server, key: &str, field: &str) -> Result<Protocol, DBError> {
match server.storage.hexists(key, field) {
match server.current_storage().hexists(key, field) {
Ok(exists) => Ok(Protocol::SimpleString(if exists { "1" } else { "0" }.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn hkeys_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
match server.storage.hkeys(key) {
match server.current_storage().hkeys(key) {
Ok(keys) => Ok(Protocol::Array(
keys.into_iter().map(Protocol::BulkString).collect(),
)),
@@ -736,7 +753,7 @@ async fn hkeys_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
}
async fn hvals_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
match server.storage.hvals(key) {
match server.current_storage().hvals(key) {
Ok(values) => Ok(Protocol::Array(
values.into_iter().map(Protocol::BulkString).collect(),
)),
@@ -745,14 +762,14 @@ async fn hvals_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
}
async fn hlen_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
match server.storage.hlen(key) {
match server.current_storage().hlen(key) {
Ok(len) => Ok(Protocol::SimpleString(len.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn hmget_cmd(server: &Server, key: &str, fields: &[String]) -> Result<Protocol, DBError> {
match server.storage.hmget(key, fields) {
match server.current_storage().hmget(key, fields) {
Ok(values) => {
let result: Vec<Protocol> = values
.into_iter()
@@ -765,14 +782,14 @@ async fn hmget_cmd(server: &Server, key: &str, fields: &[String]) -> Result<Prot
}
async fn hsetnx_cmd(server: &Server, key: &str, field: &str, value: &str) -> Result<Protocol, DBError> {
match server.storage.hsetnx(key, field, value) {
match server.current_storage().hsetnx(key, field, value) {
Ok(was_set) => Ok(Protocol::SimpleString(if was_set { "1" } else { "0" }.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn scan_cmd(server: &Server, cursor: &u64, pattern: Option<&str>, count: &Option<u64>) -> Result<Protocol, DBError> {
match server.storage.scan(*cursor, pattern, *count) {
match server.current_storage().scan(*cursor, pattern, *count) {
Ok((next_cursor, keys)) => {
let mut result = Vec::new();
result.push(Protocol::BulkString(next_cursor.to_string()));
@@ -786,7 +803,7 @@ async fn scan_cmd(server: &Server, cursor: &u64, pattern: Option<&str>, count: &
}
async fn hscan_cmd(server: &Server, key: &str, cursor: &u64, pattern: Option<&str>, count: &Option<u64>) -> Result<Protocol, DBError> {
match server.storage.hscan(key, *cursor, pattern, *count) {
match server.current_storage().hscan(key, *cursor, pattern, *count) {
Ok((next_cursor, fields)) => {
let mut result = Vec::new();
result.push(Protocol::BulkString(next_cursor.to_string()));
@@ -800,14 +817,14 @@ async fn hscan_cmd(server: &Server, key: &str, cursor: &u64, pattern: Option<&st
}
async fn ttl_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
match server.storage.ttl(key) {
match server.current_storage().ttl(key) {
Ok(ttl) => Ok(Protocol::SimpleString(ttl.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn exists_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
match server.storage.exists(key) {
match server.current_storage().exists(key) {
Ok(exists) => Ok(Protocol::SimpleString(if exists { "1" } else { "0" }.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}

View File

@@ -22,6 +22,10 @@ struct Args {
/// Enable debug mode
#[arg(long)]
debug: bool,
/// Number of logical databases (SELECT 0..N-1)
#[arg(long, default_value_t = 16)]
databases: u16,
}
#[tokio::main]
@@ -41,6 +45,7 @@ async fn main() {
dir: args.dir,
port,
debug: args.debug,
databases: args.databases,
};
// new server

View File

@@ -3,4 +3,5 @@ pub struct DBOption {
pub dir: String,
pub port: u16,
pub debug: bool,
pub databases: u16, // number of logical DBs (default 16)
}

View File

@@ -12,27 +12,36 @@ use crate::storage::Storage;
#[derive(Clone)]
pub struct Server {
pub storage: Arc<Storage>,
pub storages: Vec<Arc<Storage>>,
pub option: options::DBOption,
pub client_name: Option<String>,
pub selected_db: usize, // per-connection
}
impl Server {
pub async fn new(option: options::DBOption) -> Self {
// Create database file path with fixed filename
let db_file_path = PathBuf::from(option.dir.clone()).join("herodb.redb");
println!("will open db file path: {}", db_file_path.display());
// Initialize storage with redb
let storage = Storage::new(db_file_path).expect("Failed to initialize storage");
// Eagerly create N db files: <dir>/<index>.db
let mut storages = Vec::with_capacity(option.databases as usize);
for i in 0..option.databases {
let db_file_path = PathBuf::from(option.dir.clone()).join(format!("{}.db", i));
println!("will open db file path (db {}): {}", i, db_file_path.display());
let storage = Storage::new(db_file_path).expect("Failed to initialize storage");
storages.push(Arc::new(storage));
}
Server {
storage: Arc::new(storage),
storages,
option,
client_name: None,
selected_db: 0,
}
}
#[inline]
pub fn current_storage(&self) -> &Storage {
self.storages[self.selected_db].as_ref()
}
pub async fn handle(
&mut self,
mut stream: tokio::net::TcpStream,