diff --git a/herodb/src/cmd.rs b/herodb/src/cmd.rs index c036b4a..eef93c4 100644 --- a/herodb/src/cmd.rs +++ b/herodb/src/cmd.rs @@ -1,5 +1,7 @@ use crate::{error::DBError, protocol::Protocol, server::Server}; use serde::Serialize; +use tokio::time::{timeout, Duration}; +use futures::future::select_all; #[derive(Debug, Clone)] pub enum Cmd { @@ -43,6 +45,7 @@ pub enum Cmd { RPush(String, Vec), LPop(String, Option), RPop(String, Option), + BLPop(Vec, f64), LLen(String), LRem(String, i64, String), LTrim(String, i64, i64), @@ -376,6 +379,17 @@ impl Cmd { }; Cmd::RPop(cmd[1].clone(), count) } + "blpop" => { + if cmd.len() < 3 { + return Err(DBError(format!("wrong number of arguments for BLPOP command"))); + } + // keys are all but the last argument + let keys = cmd[1..cmd.len()-1].to_vec(); + let timeout_f = cmd[cmd.len()-1] + .parse::() + .map_err(|_| DBError("ERR timeout is not a number".to_string()))?; + Cmd::BLPop(keys, timeout_f) + } "llen" => { if cmd.len() != 2 { return Err(DBError(format!("wrong number of arguments for LLEN command"))); @@ -531,6 +545,7 @@ impl Cmd { Cmd::RPush(key, elements) => rpush_cmd(server, &key, &elements).await, Cmd::LPop(key, count) => lpop_cmd(server, &key, &count).await, Cmd::RPop(key, count) => rpop_cmd(server, &key, &count).await, + Cmd::BLPop(keys, timeout) => blpop_cmd(server, &keys, timeout).await, Cmd::LLen(key) => llen_cmd(server, &key).await, Cmd::LRem(key, count, element) => lrem_cmd(server, &key, count, &element).await, Cmd::LTrim(key, start, stop) => ltrim_cmd(server, &key, start, stop).await, @@ -661,16 +676,106 @@ async fn rpop_cmd(server: &Server, key: &str, count: &Option) -> Result Result { + // Immediate, non-blocking attempt in key order + for k in keys { + let elems = server.current_storage()?.lpop(k, 1)?; + if !elems.is_empty() { + return Ok(Protocol::Array(vec![ + Protocol::BulkString(k.clone()), + Protocol::BulkString(elems[0].clone()), + ])); + } + } + + // If timeout is zero, return immediately with Null + if timeout_secs <= 0.0 { + return Ok(Protocol::Null); + } + + // Register waiters for each key + let db_index = server.selected_db; + let mut ids: Vec = Vec::with_capacity(keys.len()); + let mut names: Vec = Vec::with_capacity(keys.len()); + let mut rxs: Vec> = Vec::with_capacity(keys.len()); + + for k in keys { + let (id, rx) = server.register_waiter(db_index, k).await; + ids.push(id); + names.push(k.clone()); + rxs.push(rx); + } + + // Wait for the first delivery or timeout + let wait_fut = async move { + let mut futures_vec = rxs; + loop { + if futures_vec.is_empty() { + return None; + } + let (res, idx, remaining) = select_all(futures_vec).await; + match res { + Ok((k, elem)) => { + return Some((k, elem, idx, remaining)); + } + Err(_canceled) => { + // That waiter was canceled; continue with the rest + futures_vec = remaining; + continue; + } + } + } + }; + + match timeout(Duration::from_secs_f64(timeout_secs), wait_fut).await { + Ok(Some((k, elem, idx, _remaining))) => { + // Unregister other waiters + for (i, key_name) in names.iter().enumerate() { + if i != idx { + server.unregister_waiter(db_index, key_name, ids[i]).await; + } + } + Ok(Protocol::Array(vec![ + Protocol::BulkString(k), + Protocol::BulkString(elem), + ])) + } + Ok(None) => { + // No futures left; unregister all waiters + for (i, key_name) in names.iter().enumerate() { + server.unregister_waiter(db_index, key_name, ids[i]).await; + } + Ok(Protocol::Null) + } + Err(_elapsed) => { + // Timeout: unregister all waiters + for (i, key_name) in names.iter().enumerate() { + server.unregister_waiter(db_index, key_name, ids[i]).await; + } + Ok(Protocol::Null) + } + } +} + async fn lpush_cmd(server: &Server, key: &str, elements: &[String]) -> Result { match server.current_storage()?.lpush(key, elements.to_vec()) { - Ok(len) => Ok(Protocol::SimpleString(len.to_string())), + Ok(len) => { + // Attempt to deliver to any blocked BLPOP waiters + let _ = server.drain_waiters_after_push(key).await; + Ok(Protocol::SimpleString(len.to_string())) + } Err(e) => Ok(Protocol::err(&e.0)), } } async fn rpush_cmd(server: &Server, key: &str, elements: &[String]) -> Result { match server.current_storage()?.rpush(key, elements.to_vec()) { - Ok(len) => Ok(Protocol::SimpleString(len.to_string())), + Ok(len) => { + // Attempt to deliver to any blocked BLPOP waiters + let _ = server.drain_waiters_after_push(key).await; + Ok(Protocol::SimpleString(len.to_string())) + } Err(e) => Ok(Protocol::err(&e.0)), } } diff --git a/herodb/src/server.rs b/herodb/src/server.rs index c286e21..68a0219 100644 --- a/herodb/src/server.rs +++ b/herodb/src/server.rs @@ -3,6 +3,9 @@ use std::collections::HashMap; use std::sync::Arc; use tokio::io::AsyncReadExt; use tokio::io::AsyncWriteExt; +use tokio::sync::{Mutex, oneshot}; + +use std::sync::atomic::{AtomicU64, Ordering}; use crate::cmd::Cmd; use crate::error::DBError; @@ -17,6 +20,15 @@ pub struct Server { pub client_name: Option, pub selected_db: u64, // Changed from usize to u64 pub queued_cmd: Option>, + + // BLPOP waiter registry: per (db_index, key) FIFO of waiters + pub list_waiters: Arc>>>>, + pub waiter_seq: Arc, +} + +pub struct Waiter { + pub id: u64, + pub tx: oneshot::Sender<(String, String)>, // (key, element) } impl Server { @@ -27,6 +39,9 @@ impl Server { client_name: None, selected_db: 0, queued_cmd: None, + + list_waiters: Arc::new(Mutex::new(HashMap::new())), + waiter_seq: Arc::new(AtomicU64::new(1)), } } @@ -66,6 +81,81 @@ impl Server { self.option.encrypt && db_index >= 10 } + // ----- BLPOP waiter helpers ----- + + pub async fn register_waiter(&self, db_index: u64, key: &str) -> (u64, oneshot::Receiver<(String, String)>) { + let id = self.waiter_seq.fetch_add(1, Ordering::Relaxed); + let (tx, rx) = oneshot::channel::<(String, String)>(); + + let mut guard = self.list_waiters.lock().await; + let per_db = guard.entry(db_index).or_insert_with(HashMap::new); + let q = per_db.entry(key.to_string()).or_insert_with(Vec::new); + q.push(Waiter { id, tx }); + (id, rx) + } + + pub async fn unregister_waiter(&self, db_index: u64, key: &str, id: u64) { + let mut guard = self.list_waiters.lock().await; + if let Some(per_db) = guard.get_mut(&db_index) { + if let Some(q) = per_db.get_mut(key) { + q.retain(|w| w.id != id); + if q.is_empty() { + per_db.remove(key); + } + } + if per_db.is_empty() { + guard.remove(&db_index); + } + } + } + + // Called after LPUSH/RPUSH to deliver to blocked BLPOP waiters. + pub async fn drain_waiters_after_push(&self, key: &str) -> Result<(), DBError> { + let db_index = self.selected_db; + + loop { + // Check if any waiter exists + let maybe_waiter = { + let mut guard = self.list_waiters.lock().await; + if let Some(per_db) = guard.get_mut(&db_index) { + if let Some(q) = per_db.get_mut(key) { + if !q.is_empty() { + // Pop FIFO + Some(q.remove(0)) + } else { + None + } + } else { + None + } + } else { + None + } + }; + + let waiter = if let Some(w) = maybe_waiter { w } else { break }; + + // Pop one element from the left + let elems = self.current_storage()?.lpop(key, 1)?; + if elems.is_empty() { + // Nothing to deliver; re-register waiter at the front to preserve order + let mut guard = self.list_waiters.lock().await; + let per_db = guard.entry(db_index).or_insert_with(HashMap::new); + let q = per_db.entry(key.to_string()).or_insert_with(Vec::new); + q.insert(0, waiter); + break; + } else { + let elem = elems[0].clone(); + // Send to waiter; if receiver dropped, just continue + let _ = waiter.tx.send((key.to_string(), elem)); + // Loop to try to satisfy more waiters if more elements remain + continue; + } + } + + Ok(()) + } + pub async fn handle( &mut self, mut stream: tokio::net::TcpStream,