implemented HINCRBY/HINCRBYFLOAT + fixed partial-frame handling bug causing sporadic protocol parsing errors when sending or receiving large bulk strings (AGE ciphertext/signature) (happend because TCP segmentation can split a single RESP frame; both client & server assumed a single read would contain the whole frame)
This commit is contained in:
		| @@ -12,6 +12,8 @@ pub enum Cmd { | ||||
|     Set(String, String), | ||||
|     SetPx(String, String, u128), | ||||
|     SetEx(String, String, u128), | ||||
|     // Advanced SET with options: (key, value, ex_ms, nx, xx, get) | ||||
|     SetOpts(String, String, Option<u128>, bool, bool, bool), | ||||
|     MGet(Vec<String>), | ||||
|     MSet(Vec<(String, String)>), | ||||
|     Keys, | ||||
| @@ -34,6 +36,8 @@ pub enum Cmd { | ||||
|     HLen(String), | ||||
|     HMGet(String, Vec<String>), | ||||
|     HSetNx(String, String, String), | ||||
|     HIncrBy(String, String, i64), | ||||
|     HIncrByFloat(String, String, f64), | ||||
|     HScan(String, u64, Option<String>, Option<u64>), // key, cursor, pattern, count | ||||
|     Scan(u64, Option<String>, Option<u64>), // cursor, pattern, count | ||||
|     Ttl(String), | ||||
| @@ -101,14 +105,51 @@ impl Cmd { | ||||
|                         "ping" => Cmd::Ping, | ||||
|                         "get" => Cmd::Get(cmd[1].clone()), | ||||
|                         "set" => { | ||||
|                             if cmd.len() == 5 && cmd[3].to_lowercase() == "px" { | ||||
|                                 Cmd::SetPx(cmd[1].clone(), cmd[2].clone(), cmd[4].parse().unwrap()) | ||||
|                             } else if cmd.len() == 5 && cmd[3].to_lowercase() == "ex" { | ||||
|                                 Cmd::SetEx(cmd[1].clone(), cmd[2].clone(), cmd[4].parse().unwrap()) | ||||
|                             } else if cmd.len() == 3 { | ||||
|                                 Cmd::Set(cmd[1].clone(), cmd[2].clone()) | ||||
|                             if cmd.len() < 3 { | ||||
|                                 return Err(DBError("wrong number of arguments for SET".to_string())); | ||||
|                             } | ||||
|                             let key = cmd[1].clone(); | ||||
|                             let val = cmd[2].clone(); | ||||
|  | ||||
|                             // Parse optional flags: EX sec | PX ms | NX | XX | GET | ||||
|                             let mut ex_ms: Option<u128> = None; | ||||
|                             let mut nx = false; | ||||
|                             let mut xx = false; | ||||
|                             let mut getflag = false; | ||||
|  | ||||
|                             let mut i = 3; | ||||
|                             while i < cmd.len() { | ||||
|                                 match cmd[i].to_lowercase().as_str() { | ||||
|                                     "ex" => { | ||||
|                                         if i + 1 >= cmd.len() { | ||||
|                                             return Err(DBError("ERR syntax error".to_string())); | ||||
|                                         } | ||||
|                                         let secs: u128 = cmd[i + 1].parse().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?; | ||||
|                                         ex_ms = Some(secs * 1000); | ||||
|                                         i += 2; | ||||
|                                     } | ||||
|                                     "px" => { | ||||
|                                         if i + 1 >= cmd.len() { | ||||
|                                             return Err(DBError("ERR syntax error".to_string())); | ||||
|                                         } | ||||
|                                         let ms: u128 = cmd[i + 1].parse().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?; | ||||
|                                         ex_ms = Some(ms); | ||||
|                                         i += 2; | ||||
|                                     } | ||||
|                                     "nx" => { nx = true; i += 1; } | ||||
|                                     "xx" => { xx = true; i += 1; } | ||||
|                                     "get" => { getflag = true; i += 1; } | ||||
|                                     _ => { | ||||
|                                         return Err(DBError(format!("unsupported cmd {:?}", cmd))); | ||||
|                                     } | ||||
|                                 } | ||||
|                             } | ||||
|  | ||||
|                             // If no options, keep legacy behavior | ||||
|                             if ex_ms.is_none() && !nx && !xx && !getflag { | ||||
|                                 Cmd::Set(key, val) | ||||
|                             } else { | ||||
|                                 return Err(DBError(format!("unsupported cmd {:?}", cmd))); | ||||
|                                 Cmd::SetOpts(key, val, ex_ms, nx, xx, getflag) | ||||
|                             } | ||||
|                         } | ||||
|                         "setex" => { | ||||
| @@ -259,6 +300,20 @@ impl Cmd { | ||||
|                             } | ||||
|                             Cmd::HSetNx(cmd[1].clone(), cmd[2].clone(), cmd[3].clone()) | ||||
|                         } | ||||
|                         "hincrby" => { | ||||
|                             if cmd.len() != 4 { | ||||
|                                 return Err(DBError(format!("wrong number of arguments for HINCRBY command"))); | ||||
|                             } | ||||
|                             let delta = cmd[3].parse::<i64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?; | ||||
|                             Cmd::HIncrBy(cmd[1].clone(), cmd[2].clone(), delta) | ||||
|                         } | ||||
|                         "hincrbyfloat" => { | ||||
|                             if cmd.len() != 4 { | ||||
|                                 return Err(DBError(format!("wrong number of arguments for HINCRBYFLOAT command"))); | ||||
|                             } | ||||
|                             let delta = cmd[3].parse::<f64>().map_err(|_| DBError("ERR value is not a valid float".to_string()))?; | ||||
|                             Cmd::HIncrByFloat(cmd[1].clone(), cmd[2].clone(), delta) | ||||
|                         } | ||||
|                         "hscan" => { | ||||
|                             if cmd.len() < 3 { | ||||
|                                 return Err(DBError(format!("wrong number of arguments for HSCAN command"))); | ||||
| @@ -560,6 +615,7 @@ impl Cmd { | ||||
|             Cmd::Set(k, v) => set_cmd(server, &k, &v).await, | ||||
|             Cmd::SetPx(k, v, x) => set_px_cmd(server, &k, &v, &x).await, | ||||
|             Cmd::SetEx(k, v, x) => set_ex_cmd(server, &k, &v, &x).await, | ||||
|             Cmd::SetOpts(k, v, ex_ms, nx, xx, getflag) => set_with_opts_cmd(server, &k, &v, ex_ms, nx, xx, getflag).await, | ||||
|             Cmd::MGet(keys) => mget_cmd(server, &keys).await, | ||||
|             Cmd::MSet(pairs) => mset_cmd(server, &pairs).await, | ||||
|             Cmd::Del(k) => del_cmd(server, &k).await, | ||||
| @@ -593,6 +649,8 @@ impl Cmd { | ||||
|             Cmd::HLen(key) => hlen_cmd(server, &key).await, | ||||
|             Cmd::HMGet(key, fields) => hmget_cmd(server, &key, &fields).await, | ||||
|             Cmd::HSetNx(key, field, value) => hsetnx_cmd(server, &key, &field, &value).await, | ||||
|             Cmd::HIncrBy(key, field, delta) => hincrby_cmd(server, &key, &field, delta).await, | ||||
|             Cmd::HIncrByFloat(key, field, delta) => hincrbyfloat_cmd(server, &key, &field, delta).await, | ||||
|             Cmd::HScan(key, cursor, pattern, count) => hscan_cmd(server, &key, &cursor, pattern.as_deref(), &count).await, | ||||
|             Cmd::Scan(cursor, pattern, count) => scan_cmd(server, &cursor, pattern.as_deref(), &count).await, | ||||
|             Cmd::Ttl(key) => ttl_cmd(server, &key).await, | ||||
| @@ -981,6 +1039,62 @@ async fn set_cmd(server: &Server, k: &str, v: &str) -> Result<Protocol, DBError> | ||||
|     Ok(Protocol::SimpleString("OK".to_string())) | ||||
| } | ||||
|  | ||||
| // Advanced SET with options: EX/PX/NX/XX/GET | ||||
| async fn set_with_opts_cmd( | ||||
|     server: &Server, | ||||
|     key: &str, | ||||
|     value: &str, | ||||
|     ex_ms: Option<u128>, | ||||
|     nx: bool, | ||||
|     xx: bool, | ||||
|     get_old: bool, | ||||
| ) -> Result<Protocol, DBError> { | ||||
|     let storage = server.current_storage()?; | ||||
|  | ||||
|     // Determine existence (for NX/XX) | ||||
|     let exists = storage.exists(key)?; | ||||
|  | ||||
|     // If both NX and XX, condition can never be satisfied -> no-op | ||||
|     let mut should_set = true; | ||||
|     if nx && exists { | ||||
|         should_set = false; | ||||
|     } | ||||
|     if xx && !exists { | ||||
|         should_set = false; | ||||
|     } | ||||
|  | ||||
|     // Fetch old value if needed for GET | ||||
|     let old_val = if get_old { | ||||
|         storage.get(key)? | ||||
|     } else { | ||||
|         None | ||||
|     }; | ||||
|  | ||||
|     if should_set { | ||||
|         if let Some(ms) = ex_ms { | ||||
|             storage.setx(key.to_string(), value.to_string(), ms)?; | ||||
|         } else { | ||||
|             storage.set(key.to_string(), value.to_string())?; | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     if get_old { | ||||
|         // Return previous value (or Null), regardless of NX/XX outcome only if set executed? | ||||
|         // We follow Redis semantics: return old value if set executed, else Null | ||||
|         if should_set { | ||||
|             Ok(old_val.map_or(Protocol::Null, Protocol::BulkString)) | ||||
|         } else { | ||||
|             Ok(Protocol::Null) | ||||
|         } | ||||
|     } else { | ||||
|         if should_set { | ||||
|             Ok(Protocol::SimpleString("OK".to_string())) | ||||
|         } else { | ||||
|             Ok(Protocol::Null) | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| // MGET: return array of bulk strings or Null for missing | ||||
| async fn mget_cmd(server: &Server, keys: &[String]) -> Result<Protocol, DBError> { | ||||
|     let mut out: Vec<Protocol> = Vec::with_capacity(keys.len()); | ||||
| @@ -1120,6 +1234,32 @@ async fn hsetnx_cmd(server: &Server, key: &str, field: &str, value: &str) -> Res | ||||
|     } | ||||
| } | ||||
|  | ||||
| async fn hincrby_cmd(server: &Server, key: &str, field: &str, delta: i64) -> Result<Protocol, DBError> { | ||||
|     let storage = server.current_storage()?; | ||||
|     let current = storage.hget(key, field)?; | ||||
|     let base: i64 = match current { | ||||
|         Some(v) => v.parse::<i64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?, | ||||
|         None => 0, | ||||
|     }; | ||||
|     let new_val = base.checked_add(delta).ok_or_else(|| DBError("ERR increment or decrement would overflow".to_string()))?; | ||||
|     // Update the field | ||||
|     storage.hset(key, vec![(field.to_string(), new_val.to_string())])?; | ||||
|     Ok(Protocol::SimpleString(new_val.to_string())) | ||||
| } | ||||
|  | ||||
| async fn hincrbyfloat_cmd(server: &Server, key: &str, field: &str, delta: f64) -> Result<Protocol, DBError> { | ||||
|     let storage = server.current_storage()?; | ||||
|     let current = storage.hget(key, field)?; | ||||
|     let base: f64 = match current { | ||||
|         Some(v) => v.parse::<f64>().map_err(|_| DBError("ERR value is not a valid float".to_string()))?, | ||||
|         None => 0.0, | ||||
|     }; | ||||
|     let new_val = base + delta; | ||||
|     // Update the field | ||||
|     storage.hset(key, vec![(field.to_string(), new_val.to_string())])?; | ||||
|     Ok(Protocol::SimpleString(new_val.to_string())) | ||||
| } | ||||
|  | ||||
| async fn scan_cmd( | ||||
|     server: &Server, | ||||
|     cursor: &u64, | ||||
|   | ||||
| @@ -19,6 +19,10 @@ impl fmt::Display for Protocol { | ||||
|  | ||||
| impl Protocol { | ||||
|     pub fn from(protocol: &str) -> Result<(Self, &str), DBError> { | ||||
|         if protocol.is_empty() { | ||||
|             // Incomplete frame; caller should read more bytes | ||||
|             return Err(DBError("[incomplete] empty".to_string())); | ||||
|         } | ||||
|         let ret = match protocol.chars().nth(0) { | ||||
|             Some('+') => Self::parse_simple_string_sfx(&protocol[1..]), | ||||
|             Some('$') => Self::parse_bulk_string_sfx(&protocol[1..]), | ||||
| @@ -101,21 +105,20 @@ impl Protocol { | ||||
|             let size = Self::parse_usize(&protocol[..len_end])?; | ||||
|             let data_start = len_end + 2; | ||||
|             let data_end = data_start + size; | ||||
|             let s = Self::parse_string(&protocol[data_start..data_end])?; | ||||
|  | ||||
|             if protocol.len() < data_end + 2 || &protocol[data_end..data_end+2] != "\r\n" { | ||||
|                  Err(DBError(format!( | ||||
|                     "[new bulk string] unmatched string length in prototocl {:?}", | ||||
|                     protocol, | ||||
|                 ))) | ||||
|             } else { | ||||
|                 Ok((Protocol::BulkString(s), &protocol[data_end + 2..])) | ||||
|             // If we don't yet have the full bulk payload + trailing CRLF, signal INCOMPLETE | ||||
|             if protocol.len() < data_end + 2 { | ||||
|                 return Err(DBError("[incomplete] bulk body".to_string())); | ||||
|             } | ||||
|             if &protocol[data_end..data_end + 2] != "\r\n" { | ||||
|                 return Err(DBError("[incomplete] bulk terminator".to_string())); | ||||
|             } | ||||
|  | ||||
|             let s = Self::parse_string(&protocol[data_start..data_end])?; | ||||
|             Ok((Protocol::BulkString(s), &protocol[data_end + 2..])) | ||||
|         } else { | ||||
|             Err(DBError(format!( | ||||
|                 "[new bulk string] unsupported protocol: {:?}", | ||||
|                 protocol | ||||
|             ))) | ||||
|             // No CRLF after bulk length header yet | ||||
|             Err(DBError("[incomplete] bulk header".to_string())) | ||||
|         } | ||||
|     } | ||||
|  | ||||
| @@ -125,16 +128,25 @@ impl Protocol { | ||||
|             let mut remaining = &s[len_end + 2..]; | ||||
|             let mut vec = vec![]; | ||||
|             for _ in 0..array_len { | ||||
|                 let (p, rem) = Protocol::from(remaining)?; | ||||
|                 vec.push(p); | ||||
|                 remaining = rem; | ||||
|                 match Protocol::from(remaining) { | ||||
|                     Ok((p, rem)) => { | ||||
|                         vec.push(p); | ||||
|                         remaining = rem; | ||||
|                     } | ||||
|                     Err(e) => { | ||||
|                         // Propagate incomplete so caller can read more bytes | ||||
|                         if e.0.starts_with("[incomplete]") { | ||||
|                             return Err(e); | ||||
|                         } else { | ||||
|                             return Err(e); | ||||
|                         } | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|             Ok((Protocol::Array(vec), remaining)) | ||||
|         } else { | ||||
|             Err(DBError(format!( | ||||
|                 "[new array] unsupported protocol: {:?}", | ||||
|                 s | ||||
|             ))) | ||||
|             // No CRLF after array header yet | ||||
|             Err(DBError("[incomplete] array header".to_string())) | ||||
|         } | ||||
|     } | ||||
|  | ||||
|   | ||||
| @@ -160,31 +160,40 @@ impl Server { | ||||
|         &mut self, | ||||
|         mut stream: tokio::net::TcpStream, | ||||
|     ) -> Result<(), DBError> { | ||||
|         let mut buf = [0; 512]; | ||||
|          | ||||
|         // Accumulate incoming bytes to handle partial RESP frames | ||||
|         let mut acc = String::new(); | ||||
|         let mut buf = vec![0u8; 8192]; | ||||
|  | ||||
|         loop { | ||||
|             let len = match stream.read(&mut buf).await { | ||||
|             let n = match stream.read(&mut buf).await { | ||||
|                 Ok(0) => { | ||||
|                     println!("[handle] connection closed"); | ||||
|                     return Ok(()); | ||||
|                 } | ||||
|                 Ok(len) => len, | ||||
|                 Ok(n) => n, | ||||
|                 Err(e) => { | ||||
|                     println!("[handle] read error: {:?}", e); | ||||
|                     return Err(e.into()); | ||||
|                 } | ||||
|             }; | ||||
|  | ||||
|             let mut s = str::from_utf8(&buf[..len])?; | ||||
|             while !s.is_empty() { | ||||
|                 let (cmd, protocol, remaining) = match Cmd::from(s) { | ||||
|             // Append to accumulator. RESP for our usage is ASCII-safe. | ||||
|             acc.push_str(str::from_utf8(&buf[..n])?); | ||||
|  | ||||
|             // Try to parse as many complete commands as are available in 'acc'. | ||||
|             loop { | ||||
|                 let parsed = Cmd::from(&acc); | ||||
|                 let (cmd, protocol, remaining) = match parsed { | ||||
|                     Ok((cmd, protocol, remaining)) => (cmd, protocol, remaining), | ||||
|                     Err(e) => { | ||||
|                         println!("\x1b[31;1mprotocol error: {:?}\x1b[0m", e); | ||||
|                         (Cmd::Unknow("protocol_error".to_string()), Protocol::err(&format!("protocol error: {}", e.0)), "") | ||||
|                     Err(_e) => { | ||||
|                         // Incomplete or invalid frame; assume incomplete and wait for more data. | ||||
|                         // This avoids emitting spurious protocol_error for split frames. | ||||
|                         break; | ||||
|                     } | ||||
|                 }; | ||||
|                 s = remaining; | ||||
|  | ||||
|                 // Advance the accumulator to the unparsed remainder | ||||
|                 acc = remaining.to_string(); | ||||
|  | ||||
|                 if self.option.debug { | ||||
|                     println!("\x1b[34;1mgot command: {:?}, protocol: {:?}\x1b[0m", cmd, protocol); | ||||
| @@ -204,7 +213,7 @@ impl Server { | ||||
|                         Protocol::err(&format!("ERR {}", e.0)) | ||||
|                     } | ||||
|                 }; | ||||
|                  | ||||
|  | ||||
|                 if self.option.debug { | ||||
|                     println!("\x1b[34;1mqueued cmd {:?}\x1b[0m", self.queued_cmd); | ||||
|                     println!("\x1b[32;1mgoing to send response {}\x1b[0m", res.encode()); | ||||
| @@ -220,6 +229,11 @@ impl Server { | ||||
|                     println!("[handle] QUIT command received, closing connection"); | ||||
|                     return Ok(()); | ||||
|                 } | ||||
|  | ||||
|                 // Continue parsing any further complete commands already in 'acc' | ||||
|                 if acc.is_empty() { | ||||
|                     break; | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|   | ||||
| @@ -70,14 +70,107 @@ async fn connect(port: u16) -> TcpStream { | ||||
|     } | ||||
| } | ||||
|  | ||||
| fn find_crlf(buf: &[u8], start: usize) -> Option<usize> { | ||||
|     let mut i = start; | ||||
|     while i + 1 < buf.len() { | ||||
|         if buf[i] == b'\r' && buf[i + 1] == b'\n' { | ||||
|             return Some(i); | ||||
|         } | ||||
|         i += 1; | ||||
|     } | ||||
|     None | ||||
| } | ||||
|  | ||||
| fn parse_number_i64(buf: &[u8], start: usize, end: usize) -> Option<i64> { | ||||
|     let s = std::str::from_utf8(&buf[start..end]).ok()?; | ||||
|     s.parse::<i64>().ok() | ||||
| } | ||||
|  | ||||
| // Return number of bytes that make up a complete RESP element starting at 'i', or None if incomplete. | ||||
| fn parse_elem(buf: &[u8], i: usize) -> Option<usize> { | ||||
|     if i >= buf.len() { | ||||
|         return None; | ||||
|     } | ||||
|     match buf[i] { | ||||
|         b'+' | b'-' | b':' => { | ||||
|             let end = find_crlf(buf, i + 1)?; | ||||
|             Some(end + 2 - i) | ||||
|         } | ||||
|         b'$' => { | ||||
|             let hdr_end = find_crlf(buf, i + 1)?; | ||||
|             let n = parse_number_i64(buf, i + 1, hdr_end)?; | ||||
|             if n < 0 { | ||||
|                 // Null bulk string: only header | ||||
|                 Some(hdr_end + 2 - i) | ||||
|             } else { | ||||
|                 let need = hdr_end + 2 + (n as usize) + 2; | ||||
|                 if need <= buf.len() { | ||||
|                     Some(need - i) | ||||
|                 } else { | ||||
|                     None | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|         b'*' => { | ||||
|             let hdr_end = find_crlf(buf, i + 1)?; | ||||
|             let n = parse_number_i64(buf, i + 1, hdr_end)?; | ||||
|             if n < 0 { | ||||
|                 // Null array: only header | ||||
|                 Some(hdr_end + 2 - i) | ||||
|             } else { | ||||
|                 let mut j = hdr_end + 2; | ||||
|                 for _ in 0..(n as usize) { | ||||
|                     let consumed = parse_elem(buf, j)?; | ||||
|                     j += consumed; | ||||
|                 } | ||||
|                 Some(j - i) | ||||
|             } | ||||
|         } | ||||
|         _ => None, | ||||
|     } | ||||
| } | ||||
|  | ||||
| fn resp_frame_len(buf: &[u8]) -> Option<usize> { | ||||
|     parse_elem(buf, 0) | ||||
| } | ||||
|  | ||||
| async fn read_full_resp(stream: &mut TcpStream) -> String { | ||||
|     let mut buf: Vec<u8> = Vec::with_capacity(8192); | ||||
|     let mut tmp = vec![0u8; 4096]; | ||||
|  | ||||
|     loop { | ||||
|         if let Some(total) = resp_frame_len(&buf) { | ||||
|             if buf.len() >= total { | ||||
|                 return String::from_utf8_lossy(&buf[..total]).to_string(); | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         match tokio::time::timeout(Duration::from_secs(2), stream.read(&mut tmp)).await { | ||||
|             Ok(Ok(n)) => { | ||||
|                 if n == 0 { | ||||
|                     if let Some(total) = resp_frame_len(&buf) { | ||||
|                         if buf.len() >= total { | ||||
|                             return String::from_utf8_lossy(&buf[..total]).to_string(); | ||||
|                         } | ||||
|                     } | ||||
|                     return String::from_utf8_lossy(&buf).to_string(); | ||||
|                 } | ||||
|                 buf.extend_from_slice(&tmp[..n]); | ||||
|             } | ||||
|             Ok(Err(e)) => panic!("read error: {}", e), | ||||
|             Err(_) => panic!("timeout waiting for reply"), | ||||
|         } | ||||
|  | ||||
|         if buf.len() > 8 * 1024 * 1024 { | ||||
|             panic!("reply too large"); | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| async fn send_cmd(stream: &mut TcpStream, args: &[&str]) -> String { | ||||
|     let req = build_resp(args); | ||||
|     stream.write_all(req.as_bytes()).await.unwrap(); | ||||
|  | ||||
|     // Single read is enough for these small replies | ||||
|     let mut buf = vec![0u8; 8192]; | ||||
|     let n = stream.read(&mut buf).await.unwrap(); | ||||
|     String::from_utf8_lossy(&buf[..n]).to_string() | ||||
|     read_full_resp(stream).await | ||||
| } | ||||
|  | ||||
| // Assert helpers with clearer output | ||||
| @@ -559,6 +652,58 @@ async fn test_10_expire_pexpire_persist() { | ||||
|    assert_contains(&persist2, "0", "PERSIST again -> 0 (no expiration to remove)"); | ||||
| } | ||||
|  | ||||
| #[tokio::test] | ||||
| async fn test_11_set_with_options() { | ||||
|     let (server, port) = start_test_server("set_opts").await; | ||||
|     spawn_listener(server, port).await; | ||||
|     sleep(Duration::from_millis(150)).await; | ||||
|  | ||||
|     let mut s = connect(port).await; | ||||
|  | ||||
|     // SET with GET on non-existing key -> returns Null, sets value | ||||
|     let set_get1 = send_cmd(&mut s, &["SET", "s1", "v1", "GET"]).await; | ||||
|     assert_contains(&set_get1, "$-1", "SET s1 v1 GET returns Null when key didn't exist"); | ||||
|     let g1 = send_cmd(&mut s, &["GET", "s1"]).await; | ||||
|     assert_contains(&g1, "v1", "GET s1 after first SET"); | ||||
|  | ||||
|     // SET with GET should return old value, then set to new | ||||
|     let set_get2 = send_cmd(&mut s, &["SET", "s1", "v2", "GET"]).await; | ||||
|     assert_contains(&set_get2, "v1", "SET s1 v2 GET returns previous value v1"); | ||||
|     let g2 = send_cmd(&mut s, &["GET", "s1"]).await; | ||||
|     assert_contains(&g2, "v2", "GET s1 now v2"); | ||||
|  | ||||
|     // NX prevents update when key exists; with GET should return Null and not change | ||||
|     let set_nx = send_cmd(&mut s, &["SET", "s1", "v3", "NX", "GET"]).await; | ||||
|     assert_contains(&set_nx, "$-1", "SET s1 v3 NX GET returns Null when not set"); | ||||
|     let g3 = send_cmd(&mut s, &["GET", "s1"]).await; | ||||
|     assert_contains(&g3, "v2", "GET s1 remains v2 after NX prevented write"); | ||||
|  | ||||
|     // NX allows set when key does not exist | ||||
|     let set_nx2 = send_cmd(&mut s, &["SET", "s2", "v10", "NX"]).await; | ||||
|     assert_contains(&set_nx2, "OK", "SET s2 v10 NX -> OK for new key"); | ||||
|     let g4 = send_cmd(&mut s, &["GET", "s2"]).await; | ||||
|     assert_contains(&g4, "v10", "GET s2 is v10"); | ||||
|  | ||||
|     // XX requires existing key; with GET returns old value and sets new | ||||
|     let set_xx = send_cmd(&mut s, &["SET", "s2", "v11", "XX", "GET"]).await; | ||||
|     assert_contains(&set_xx, "v10", "SET s2 v11 XX GET returns previous v10"); | ||||
|     let g5 = send_cmd(&mut s, &["GET", "s2"]).await; | ||||
|     assert_contains(&g5, "v11", "GET s2 is now v11"); | ||||
|  | ||||
|     // PX expiration path via SET options | ||||
|     let set_px = send_cmd(&mut s, &["SET", "s3", "vpx", "PX", "500"]).await; | ||||
|     assert_contains(&set_px, "OK", "SET s3 vpx PX 500 -> OK"); | ||||
|     let ttl_px1 = send_cmd(&mut s, &["TTL", "s3"]).await; | ||||
|     assert!( | ||||
|         ttl_px1.contains("0") || ttl_px1.contains("1"), | ||||
|         "TTL s3 immediately after PX should be 1 or 0, got: {}", | ||||
|         ttl_px1 | ||||
|     ); | ||||
|     sleep(Duration::from_millis(650)).await; | ||||
|     let g6 = send_cmd(&mut s, &["GET", "s3"]).await; | ||||
|     assert_contains(&g6, "$-1", "GET s3 after PX expiry -> Null"); | ||||
| } | ||||
|  | ||||
| #[tokio::test] | ||||
| async fn test_09_mget_mset_and_variadic_exists_del() { | ||||
|    let (server, port) = start_test_server("mget_mset_variadic").await; | ||||
| @@ -597,4 +742,41 @@ async fn test_09_mget_mset_and_variadic_exists_del() { | ||||
|    assert_contains(&mget_after, "$-1", "MGET k1 after DEL -> Null"); | ||||
|    assert_contains(&mget_after, "v2", "MGET k2 remains"); | ||||
|    assert_contains(&mget_after, "$-1", "MGET k3 after DEL -> Null"); | ||||
| } | ||||
| #[tokio::test] | ||||
| async fn test_12_hash_incr() { | ||||
|     let (server, port) = start_test_server("hash_incr").await; | ||||
|     spawn_listener(server, port).await; | ||||
|     sleep(Duration::from_millis(150)).await; | ||||
|  | ||||
|     let mut s = connect(port).await; | ||||
|  | ||||
|     // Integer increments | ||||
|     let _ = send_cmd(&mut s, &["HSET", "hinc", "a", "1"]).await; | ||||
|     let r1 = send_cmd(&mut s, &["HINCRBY", "hinc", "a", "2"]).await; | ||||
|     assert_contains(&r1, "3", "HINCRBY hinc a 2 -> 3"); | ||||
|  | ||||
|     let r2 = send_cmd(&mut s, &["HINCRBY", "hinc", "a", "-1"]).await; | ||||
|     assert_contains(&r2, "2", "HINCRBY hinc a -1 -> 2"); | ||||
|  | ||||
|     let r3 = send_cmd(&mut s, &["HINCRBY", "hinc", "b", "5"]).await; | ||||
|     assert_contains(&r3, "5", "HINCRBY hinc b 5 -> 5"); | ||||
|  | ||||
|     // HINCRBY error on non-integer field | ||||
|     let _ = send_cmd(&mut s, &["HSET", "hinc", "s", "x"]).await; | ||||
|     let r_err = send_cmd(&mut s, &["HINCRBY", "hinc", "s", "1"]).await; | ||||
|     assert_contains(&r_err, "ERR", "HINCRBY on non-integer field should ERR"); | ||||
|  | ||||
|     // Float increments | ||||
|     let r4 = send_cmd(&mut s, &["HINCRBYFLOAT", "hinc", "f", "1.5"]).await; | ||||
|     assert_contains(&r4, "1.5", "HINCRBYFLOAT hinc f 1.5 -> 1.5"); | ||||
|  | ||||
|     let r5 = send_cmd(&mut s, &["HINCRBYFLOAT", "hinc", "f", "2.5"]).await; | ||||
|     // Could be "4", "4.0", or "4.000000", accept "4" substring | ||||
|     assert_contains(&r5, "4", "HINCRBYFLOAT hinc f 2.5 -> 4"); | ||||
|  | ||||
|     // HINCRBYFLOAT error on non-float field | ||||
|     let _ = send_cmd(&mut s, &["HSET", "hinc", "notf", "abc"]).await; | ||||
|     let r6 = send_cmd(&mut s, &["HINCRBYFLOAT", "hinc", "notf", "1"]).await; | ||||
|     assert_contains(&r6, "ERR", "HINCRBYFLOAT on non-float field should ERR"); | ||||
| } | ||||
		Reference in New Issue
	
	Block a user