diff --git a/herodb/src/cmd.rs b/herodb/src/cmd.rs index 59021e7..8a9058c 100644 --- a/herodb/src/cmd.rs +++ b/herodb/src/cmd.rs @@ -12,6 +12,8 @@ pub enum Cmd { Set(String, String), SetPx(String, String, u128), SetEx(String, String, u128), + // Advanced SET with options: (key, value, ex_ms, nx, xx, get) + SetOpts(String, String, Option, bool, bool, bool), MGet(Vec), MSet(Vec<(String, String)>), Keys, @@ -34,6 +36,8 @@ pub enum Cmd { HLen(String), HMGet(String, Vec), HSetNx(String, String, String), + HIncrBy(String, String, i64), + HIncrByFloat(String, String, f64), HScan(String, u64, Option, Option), // key, cursor, pattern, count Scan(u64, Option, Option), // cursor, pattern, count Ttl(String), @@ -101,14 +105,51 @@ impl Cmd { "ping" => Cmd::Ping, "get" => Cmd::Get(cmd[1].clone()), "set" => { - if cmd.len() == 5 && cmd[3].to_lowercase() == "px" { - Cmd::SetPx(cmd[1].clone(), cmd[2].clone(), cmd[4].parse().unwrap()) - } else if cmd.len() == 5 && cmd[3].to_lowercase() == "ex" { - Cmd::SetEx(cmd[1].clone(), cmd[2].clone(), cmd[4].parse().unwrap()) - } else if cmd.len() == 3 { - Cmd::Set(cmd[1].clone(), cmd[2].clone()) + if cmd.len() < 3 { + return Err(DBError("wrong number of arguments for SET".to_string())); + } + let key = cmd[1].clone(); + let val = cmd[2].clone(); + + // Parse optional flags: EX sec | PX ms | NX | XX | GET + let mut ex_ms: Option = None; + let mut nx = false; + let mut xx = false; + let mut getflag = false; + + let mut i = 3; + while i < cmd.len() { + match cmd[i].to_lowercase().as_str() { + "ex" => { + if i + 1 >= cmd.len() { + return Err(DBError("ERR syntax error".to_string())); + } + let secs: u128 = cmd[i + 1].parse().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?; + ex_ms = Some(secs * 1000); + i += 2; + } + "px" => { + if i + 1 >= cmd.len() { + return Err(DBError("ERR syntax error".to_string())); + } + let ms: u128 = cmd[i + 1].parse().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?; + ex_ms = Some(ms); + i += 2; + } + "nx" => { nx = true; i += 1; } + "xx" => { xx = true; i += 1; } + "get" => { getflag = true; i += 1; } + _ => { + return Err(DBError(format!("unsupported cmd {:?}", cmd))); + } + } + } + + // If no options, keep legacy behavior + if ex_ms.is_none() && !nx && !xx && !getflag { + Cmd::Set(key, val) } else { - return Err(DBError(format!("unsupported cmd {:?}", cmd))); + Cmd::SetOpts(key, val, ex_ms, nx, xx, getflag) } } "setex" => { @@ -259,6 +300,20 @@ impl Cmd { } Cmd::HSetNx(cmd[1].clone(), cmd[2].clone(), cmd[3].clone()) } + "hincrby" => { + if cmd.len() != 4 { + return Err(DBError(format!("wrong number of arguments for HINCRBY command"))); + } + let delta = cmd[3].parse::().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?; + Cmd::HIncrBy(cmd[1].clone(), cmd[2].clone(), delta) + } + "hincrbyfloat" => { + if cmd.len() != 4 { + return Err(DBError(format!("wrong number of arguments for HINCRBYFLOAT command"))); + } + let delta = cmd[3].parse::().map_err(|_| DBError("ERR value is not a valid float".to_string()))?; + Cmd::HIncrByFloat(cmd[1].clone(), cmd[2].clone(), delta) + } "hscan" => { if cmd.len() < 3 { return Err(DBError(format!("wrong number of arguments for HSCAN command"))); @@ -560,6 +615,7 @@ impl Cmd { Cmd::Set(k, v) => set_cmd(server, &k, &v).await, Cmd::SetPx(k, v, x) => set_px_cmd(server, &k, &v, &x).await, Cmd::SetEx(k, v, x) => set_ex_cmd(server, &k, &v, &x).await, + Cmd::SetOpts(k, v, ex_ms, nx, xx, getflag) => set_with_opts_cmd(server, &k, &v, ex_ms, nx, xx, getflag).await, Cmd::MGet(keys) => mget_cmd(server, &keys).await, Cmd::MSet(pairs) => mset_cmd(server, &pairs).await, Cmd::Del(k) => del_cmd(server, &k).await, @@ -593,6 +649,8 @@ impl Cmd { Cmd::HLen(key) => hlen_cmd(server, &key).await, Cmd::HMGet(key, fields) => hmget_cmd(server, &key, &fields).await, Cmd::HSetNx(key, field, value) => hsetnx_cmd(server, &key, &field, &value).await, + Cmd::HIncrBy(key, field, delta) => hincrby_cmd(server, &key, &field, delta).await, + Cmd::HIncrByFloat(key, field, delta) => hincrbyfloat_cmd(server, &key, &field, delta).await, Cmd::HScan(key, cursor, pattern, count) => hscan_cmd(server, &key, &cursor, pattern.as_deref(), &count).await, Cmd::Scan(cursor, pattern, count) => scan_cmd(server, &cursor, pattern.as_deref(), &count).await, Cmd::Ttl(key) => ttl_cmd(server, &key).await, @@ -981,6 +1039,62 @@ async fn set_cmd(server: &Server, k: &str, v: &str) -> Result Ok(Protocol::SimpleString("OK".to_string())) } +// Advanced SET with options: EX/PX/NX/XX/GET +async fn set_with_opts_cmd( + server: &Server, + key: &str, + value: &str, + ex_ms: Option, + nx: bool, + xx: bool, + get_old: bool, +) -> Result { + let storage = server.current_storage()?; + + // Determine existence (for NX/XX) + let exists = storage.exists(key)?; + + // If both NX and XX, condition can never be satisfied -> no-op + let mut should_set = true; + if nx && exists { + should_set = false; + } + if xx && !exists { + should_set = false; + } + + // Fetch old value if needed for GET + let old_val = if get_old { + storage.get(key)? + } else { + None + }; + + if should_set { + if let Some(ms) = ex_ms { + storage.setx(key.to_string(), value.to_string(), ms)?; + } else { + storage.set(key.to_string(), value.to_string())?; + } + } + + if get_old { + // Return previous value (or Null), regardless of NX/XX outcome only if set executed? + // We follow Redis semantics: return old value if set executed, else Null + if should_set { + Ok(old_val.map_or(Protocol::Null, Protocol::BulkString)) + } else { + Ok(Protocol::Null) + } + } else { + if should_set { + Ok(Protocol::SimpleString("OK".to_string())) + } else { + Ok(Protocol::Null) + } + } +} + // MGET: return array of bulk strings or Null for missing async fn mget_cmd(server: &Server, keys: &[String]) -> Result { let mut out: Vec = Vec::with_capacity(keys.len()); @@ -1120,6 +1234,32 @@ async fn hsetnx_cmd(server: &Server, key: &str, field: &str, value: &str) -> Res } } +async fn hincrby_cmd(server: &Server, key: &str, field: &str, delta: i64) -> Result { + let storage = server.current_storage()?; + let current = storage.hget(key, field)?; + let base: i64 = match current { + Some(v) => v.parse::().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?, + None => 0, + }; + let new_val = base.checked_add(delta).ok_or_else(|| DBError("ERR increment or decrement would overflow".to_string()))?; + // Update the field + storage.hset(key, vec![(field.to_string(), new_val.to_string())])?; + Ok(Protocol::SimpleString(new_val.to_string())) +} + +async fn hincrbyfloat_cmd(server: &Server, key: &str, field: &str, delta: f64) -> Result { + let storage = server.current_storage()?; + let current = storage.hget(key, field)?; + let base: f64 = match current { + Some(v) => v.parse::().map_err(|_| DBError("ERR value is not a valid float".to_string()))?, + None => 0.0, + }; + let new_val = base + delta; + // Update the field + storage.hset(key, vec![(field.to_string(), new_val.to_string())])?; + Ok(Protocol::SimpleString(new_val.to_string())) +} + async fn scan_cmd( server: &Server, cursor: &u64, diff --git a/herodb/src/protocol.rs b/herodb/src/protocol.rs index c9a2255..6025074 100644 --- a/herodb/src/protocol.rs +++ b/herodb/src/protocol.rs @@ -19,6 +19,10 @@ impl fmt::Display for Protocol { impl Protocol { pub fn from(protocol: &str) -> Result<(Self, &str), DBError> { + if protocol.is_empty() { + // Incomplete frame; caller should read more bytes + return Err(DBError("[incomplete] empty".to_string())); + } let ret = match protocol.chars().nth(0) { Some('+') => Self::parse_simple_string_sfx(&protocol[1..]), Some('$') => Self::parse_bulk_string_sfx(&protocol[1..]), @@ -101,21 +105,20 @@ impl Protocol { let size = Self::parse_usize(&protocol[..len_end])?; let data_start = len_end + 2; let data_end = data_start + size; - let s = Self::parse_string(&protocol[data_start..data_end])?; - if protocol.len() < data_end + 2 || &protocol[data_end..data_end+2] != "\r\n" { - Err(DBError(format!( - "[new bulk string] unmatched string length in prototocl {:?}", - protocol, - ))) - } else { - Ok((Protocol::BulkString(s), &protocol[data_end + 2..])) + // If we don't yet have the full bulk payload + trailing CRLF, signal INCOMPLETE + if protocol.len() < data_end + 2 { + return Err(DBError("[incomplete] bulk body".to_string())); } + if &protocol[data_end..data_end + 2] != "\r\n" { + return Err(DBError("[incomplete] bulk terminator".to_string())); + } + + let s = Self::parse_string(&protocol[data_start..data_end])?; + Ok((Protocol::BulkString(s), &protocol[data_end + 2..])) } else { - Err(DBError(format!( - "[new bulk string] unsupported protocol: {:?}", - protocol - ))) + // No CRLF after bulk length header yet + Err(DBError("[incomplete] bulk header".to_string())) } } @@ -125,16 +128,25 @@ impl Protocol { let mut remaining = &s[len_end + 2..]; let mut vec = vec![]; for _ in 0..array_len { - let (p, rem) = Protocol::from(remaining)?; - vec.push(p); - remaining = rem; + match Protocol::from(remaining) { + Ok((p, rem)) => { + vec.push(p); + remaining = rem; + } + Err(e) => { + // Propagate incomplete so caller can read more bytes + if e.0.starts_with("[incomplete]") { + return Err(e); + } else { + return Err(e); + } + } + } } Ok((Protocol::Array(vec), remaining)) } else { - Err(DBError(format!( - "[new array] unsupported protocol: {:?}", - s - ))) + // No CRLF after array header yet + Err(DBError("[incomplete] array header".to_string())) } } diff --git a/herodb/src/server.rs b/herodb/src/server.rs index 68a0219..23a93af 100644 --- a/herodb/src/server.rs +++ b/herodb/src/server.rs @@ -160,31 +160,40 @@ impl Server { &mut self, mut stream: tokio::net::TcpStream, ) -> Result<(), DBError> { - let mut buf = [0; 512]; - + // Accumulate incoming bytes to handle partial RESP frames + let mut acc = String::new(); + let mut buf = vec![0u8; 8192]; + loop { - let len = match stream.read(&mut buf).await { + let n = match stream.read(&mut buf).await { Ok(0) => { println!("[handle] connection closed"); return Ok(()); } - Ok(len) => len, + Ok(n) => n, Err(e) => { println!("[handle] read error: {:?}", e); return Err(e.into()); } }; - let mut s = str::from_utf8(&buf[..len])?; - while !s.is_empty() { - let (cmd, protocol, remaining) = match Cmd::from(s) { + // Append to accumulator. RESP for our usage is ASCII-safe. + acc.push_str(str::from_utf8(&buf[..n])?); + + // Try to parse as many complete commands as are available in 'acc'. + loop { + let parsed = Cmd::from(&acc); + let (cmd, protocol, remaining) = match parsed { Ok((cmd, protocol, remaining)) => (cmd, protocol, remaining), - Err(e) => { - println!("\x1b[31;1mprotocol error: {:?}\x1b[0m", e); - (Cmd::Unknow("protocol_error".to_string()), Protocol::err(&format!("protocol error: {}", e.0)), "") + Err(_e) => { + // Incomplete or invalid frame; assume incomplete and wait for more data. + // This avoids emitting spurious protocol_error for split frames. + break; } }; - s = remaining; + + // Advance the accumulator to the unparsed remainder + acc = remaining.to_string(); if self.option.debug { println!("\x1b[34;1mgot command: {:?}, protocol: {:?}\x1b[0m", cmd, protocol); @@ -204,7 +213,7 @@ impl Server { Protocol::err(&format!("ERR {}", e.0)) } }; - + if self.option.debug { println!("\x1b[34;1mqueued cmd {:?}\x1b[0m", self.queued_cmd); println!("\x1b[32;1mgoing to send response {}\x1b[0m", res.encode()); @@ -220,6 +229,11 @@ impl Server { println!("[handle] QUIT command received, closing connection"); return Ok(()); } + + // Continue parsing any further complete commands already in 'acc' + if acc.is_empty() { + break; + } } } } diff --git a/herodb/tests/usage_suite.rs b/herodb/tests/usage_suite.rs index 5ec554d..c61fecf 100644 --- a/herodb/tests/usage_suite.rs +++ b/herodb/tests/usage_suite.rs @@ -70,14 +70,107 @@ async fn connect(port: u16) -> TcpStream { } } +fn find_crlf(buf: &[u8], start: usize) -> Option { + let mut i = start; + while i + 1 < buf.len() { + if buf[i] == b'\r' && buf[i + 1] == b'\n' { + return Some(i); + } + i += 1; + } + None +} + +fn parse_number_i64(buf: &[u8], start: usize, end: usize) -> Option { + let s = std::str::from_utf8(&buf[start..end]).ok()?; + s.parse::().ok() +} + +// Return number of bytes that make up a complete RESP element starting at 'i', or None if incomplete. +fn parse_elem(buf: &[u8], i: usize) -> Option { + if i >= buf.len() { + return None; + } + match buf[i] { + b'+' | b'-' | b':' => { + let end = find_crlf(buf, i + 1)?; + Some(end + 2 - i) + } + b'$' => { + let hdr_end = find_crlf(buf, i + 1)?; + let n = parse_number_i64(buf, i + 1, hdr_end)?; + if n < 0 { + // Null bulk string: only header + Some(hdr_end + 2 - i) + } else { + let need = hdr_end + 2 + (n as usize) + 2; + if need <= buf.len() { + Some(need - i) + } else { + None + } + } + } + b'*' => { + let hdr_end = find_crlf(buf, i + 1)?; + let n = parse_number_i64(buf, i + 1, hdr_end)?; + if n < 0 { + // Null array: only header + Some(hdr_end + 2 - i) + } else { + let mut j = hdr_end + 2; + for _ in 0..(n as usize) { + let consumed = parse_elem(buf, j)?; + j += consumed; + } + Some(j - i) + } + } + _ => None, + } +} + +fn resp_frame_len(buf: &[u8]) -> Option { + parse_elem(buf, 0) +} + +async fn read_full_resp(stream: &mut TcpStream) -> String { + let mut buf: Vec = Vec::with_capacity(8192); + let mut tmp = vec![0u8; 4096]; + + loop { + if let Some(total) = resp_frame_len(&buf) { + if buf.len() >= total { + return String::from_utf8_lossy(&buf[..total]).to_string(); + } + } + + match tokio::time::timeout(Duration::from_secs(2), stream.read(&mut tmp)).await { + Ok(Ok(n)) => { + if n == 0 { + if let Some(total) = resp_frame_len(&buf) { + if buf.len() >= total { + return String::from_utf8_lossy(&buf[..total]).to_string(); + } + } + return String::from_utf8_lossy(&buf).to_string(); + } + buf.extend_from_slice(&tmp[..n]); + } + Ok(Err(e)) => panic!("read error: {}", e), + Err(_) => panic!("timeout waiting for reply"), + } + + if buf.len() > 8 * 1024 * 1024 { + panic!("reply too large"); + } + } +} + async fn send_cmd(stream: &mut TcpStream, args: &[&str]) -> String { let req = build_resp(args); stream.write_all(req.as_bytes()).await.unwrap(); - - // Single read is enough for these small replies - let mut buf = vec![0u8; 8192]; - let n = stream.read(&mut buf).await.unwrap(); - String::from_utf8_lossy(&buf[..n]).to_string() + read_full_resp(stream).await } // Assert helpers with clearer output @@ -559,6 +652,58 @@ async fn test_10_expire_pexpire_persist() { assert_contains(&persist2, "0", "PERSIST again -> 0 (no expiration to remove)"); } +#[tokio::test] +async fn test_11_set_with_options() { + let (server, port) = start_test_server("set_opts").await; + spawn_listener(server, port).await; + sleep(Duration::from_millis(150)).await; + + let mut s = connect(port).await; + + // SET with GET on non-existing key -> returns Null, sets value + let set_get1 = send_cmd(&mut s, &["SET", "s1", "v1", "GET"]).await; + assert_contains(&set_get1, "$-1", "SET s1 v1 GET returns Null when key didn't exist"); + let g1 = send_cmd(&mut s, &["GET", "s1"]).await; + assert_contains(&g1, "v1", "GET s1 after first SET"); + + // SET with GET should return old value, then set to new + let set_get2 = send_cmd(&mut s, &["SET", "s1", "v2", "GET"]).await; + assert_contains(&set_get2, "v1", "SET s1 v2 GET returns previous value v1"); + let g2 = send_cmd(&mut s, &["GET", "s1"]).await; + assert_contains(&g2, "v2", "GET s1 now v2"); + + // NX prevents update when key exists; with GET should return Null and not change + let set_nx = send_cmd(&mut s, &["SET", "s1", "v3", "NX", "GET"]).await; + assert_contains(&set_nx, "$-1", "SET s1 v3 NX GET returns Null when not set"); + let g3 = send_cmd(&mut s, &["GET", "s1"]).await; + assert_contains(&g3, "v2", "GET s1 remains v2 after NX prevented write"); + + // NX allows set when key does not exist + let set_nx2 = send_cmd(&mut s, &["SET", "s2", "v10", "NX"]).await; + assert_contains(&set_nx2, "OK", "SET s2 v10 NX -> OK for new key"); + let g4 = send_cmd(&mut s, &["GET", "s2"]).await; + assert_contains(&g4, "v10", "GET s2 is v10"); + + // XX requires existing key; with GET returns old value and sets new + let set_xx = send_cmd(&mut s, &["SET", "s2", "v11", "XX", "GET"]).await; + assert_contains(&set_xx, "v10", "SET s2 v11 XX GET returns previous v10"); + let g5 = send_cmd(&mut s, &["GET", "s2"]).await; + assert_contains(&g5, "v11", "GET s2 is now v11"); + + // PX expiration path via SET options + let set_px = send_cmd(&mut s, &["SET", "s3", "vpx", "PX", "500"]).await; + assert_contains(&set_px, "OK", "SET s3 vpx PX 500 -> OK"); + let ttl_px1 = send_cmd(&mut s, &["TTL", "s3"]).await; + assert!( + ttl_px1.contains("0") || ttl_px1.contains("1"), + "TTL s3 immediately after PX should be 1 or 0, got: {}", + ttl_px1 + ); + sleep(Duration::from_millis(650)).await; + let g6 = send_cmd(&mut s, &["GET", "s3"]).await; + assert_contains(&g6, "$-1", "GET s3 after PX expiry -> Null"); +} + #[tokio::test] async fn test_09_mget_mset_and_variadic_exists_del() { let (server, port) = start_test_server("mget_mset_variadic").await; @@ -597,4 +742,41 @@ async fn test_09_mget_mset_and_variadic_exists_del() { assert_contains(&mget_after, "$-1", "MGET k1 after DEL -> Null"); assert_contains(&mget_after, "v2", "MGET k2 remains"); assert_contains(&mget_after, "$-1", "MGET k3 after DEL -> Null"); +} +#[tokio::test] +async fn test_12_hash_incr() { + let (server, port) = start_test_server("hash_incr").await; + spawn_listener(server, port).await; + sleep(Duration::from_millis(150)).await; + + let mut s = connect(port).await; + + // Integer increments + let _ = send_cmd(&mut s, &["HSET", "hinc", "a", "1"]).await; + let r1 = send_cmd(&mut s, &["HINCRBY", "hinc", "a", "2"]).await; + assert_contains(&r1, "3", "HINCRBY hinc a 2 -> 3"); + + let r2 = send_cmd(&mut s, &["HINCRBY", "hinc", "a", "-1"]).await; + assert_contains(&r2, "2", "HINCRBY hinc a -1 -> 2"); + + let r3 = send_cmd(&mut s, &["HINCRBY", "hinc", "b", "5"]).await; + assert_contains(&r3, "5", "HINCRBY hinc b 5 -> 5"); + + // HINCRBY error on non-integer field + let _ = send_cmd(&mut s, &["HSET", "hinc", "s", "x"]).await; + let r_err = send_cmd(&mut s, &["HINCRBY", "hinc", "s", "1"]).await; + assert_contains(&r_err, "ERR", "HINCRBY on non-integer field should ERR"); + + // Float increments + let r4 = send_cmd(&mut s, &["HINCRBYFLOAT", "hinc", "f", "1.5"]).await; + assert_contains(&r4, "1.5", "HINCRBYFLOAT hinc f 1.5 -> 1.5"); + + let r5 = send_cmd(&mut s, &["HINCRBYFLOAT", "hinc", "f", "2.5"]).await; + // Could be "4", "4.0", or "4.000000", accept "4" substring + assert_contains(&r5, "4", "HINCRBYFLOAT hinc f 2.5 -> 4"); + + // HINCRBYFLOAT error on non-float field + let _ = send_cmd(&mut s, &["HSET", "hinc", "notf", "abc"]).await; + let r6 = send_cmd(&mut s, &["HINCRBYFLOAT", "hinc", "notf", "1"]).await; + assert_contains(&r6, "ERR", "HINCRBYFLOAT on non-float field should ERR"); } \ No newline at end of file