diff --git a/src/cmd.rs b/src/cmd.rs index 437331b..fcaa71e 100644 --- a/src/cmd.rs +++ b/src/cmd.rs @@ -28,6 +28,7 @@ pub enum Cmd { HLen(String), HMGet(String, Vec), HSetNx(String, String, String), + Scan(u64, Option, Option), // cursor, pattern, count Unknow, } @@ -176,6 +177,43 @@ impl Cmd { } Cmd::HSetNx(cmd[1].clone(), cmd[2].clone(), cmd[3].clone()) } + "scan" => { + if cmd.len() < 2 { + return Err(DBError(format!("wrong number of arguments for SCAN command"))); + } + + let cursor = cmd[1].parse::().map_err(|_| + DBError("ERR invalid cursor".to_string()))?; + + let mut pattern = None; + let mut count = None; + let mut i = 2; + + while i < cmd.len() { + match cmd[i].to_lowercase().as_str() { + "match" => { + if i + 1 >= cmd.len() { + return Err(DBError("ERR syntax error".to_string())); + } + pattern = Some(cmd[i + 1].clone()); + i += 2; + } + "count" => { + if i + 1 >= cmd.len() { + return Err(DBError("ERR syntax error".to_string())); + } + count = Some(cmd[i + 1].parse::().map_err(|_| + DBError("ERR value is not an integer or out of range".to_string()))?); + i += 2; + } + _ => { + return Err(DBError(format!("ERR syntax error"))); + } + } + } + + Cmd::Scan(cursor, pattern, count) + } _ => Cmd::Unknow, }, protocol.0, @@ -244,6 +282,7 @@ impl Cmd { Cmd::HLen(key) => hlen_cmd(server, key).await, Cmd::HMGet(key, fields) => hmget_cmd(server, key, fields).await, Cmd::HSetNx(key, field, value) => hsetnx_cmd(server, key, field, value).await, + Cmd::Scan(cursor, pattern, count) => scan_cmd(server, cursor, pattern.as_deref(), count).await, Cmd::Unknow => Ok(Protocol::err("unknown cmd")), } } @@ -444,3 +483,17 @@ async fn hsetnx_cmd(server: &Server, key: &str, field: &str, value: &str) -> Res Err(e) => Ok(Protocol::err(&e.0)), } } + +async fn scan_cmd(server: &Server, cursor: &u64, pattern: Option<&str>, count: &Option) -> Result { + match server.storage.scan(*cursor, pattern, *count) { + Ok((next_cursor, keys)) => { + let mut result = Vec::new(); + result.push(Protocol::BulkString(next_cursor.to_string())); + result.push(Protocol::Array( + keys.into_iter().map(Protocol::BulkString).collect(), + )); + Ok(Protocol::Array(result)) + } + Err(e) => Ok(Protocol::err(&e.0)), + } +} diff --git a/src/storage.rs b/src/storage.rs index 77c66ca..b738ce4 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -445,4 +445,65 @@ impl Storage { write_txn.commit()?; Ok(result) } + + pub fn scan(&self, cursor: u64, pattern: Option<&str>, count: Option) -> Result<(u64, Vec), DBError> { + let read_txn = self.db.begin_read()?; + let table = read_txn.open_table(TYPES_TABLE)?; + + let count = count.unwrap_or(10); // Default count is 10 + let mut keys = Vec::new(); + let mut current_cursor = 0u64; + let mut returned_keys = 0u64; + + let mut iter = table.iter()?; + while let Some(entry) = iter.next() { + let key = entry?.0.value().to_string(); + + // Skip keys until we reach the cursor position + if current_cursor < cursor { + current_cursor += 1; + continue; + } + + // Check if key matches pattern + let matches = match pattern { + Some(pat) => { + if pat == "*" { + true + } else if pat.contains('*') { + // Simple glob pattern matching + let pattern_parts: Vec<&str> = pat.split('*').collect(); + if pattern_parts.len() == 2 { + let prefix = pattern_parts[0]; + let suffix = pattern_parts[1]; + key.starts_with(prefix) && key.ends_with(suffix) + } else { + key.contains(&pat.replace('*', "")) + } + } else { + key.contains(pat) + } + } + None => true, + }; + + if matches { + keys.push(key); + returned_keys += 1; + + // Stop if we've returned enough keys + if returned_keys >= count { + current_cursor += 1; + break; + } + } + + current_cursor += 1; + } + + // If we've reached the end of iteration, return cursor 0 to indicate completion + let next_cursor = if returned_keys < count { 0 } else { current_cursor }; + + Ok((next_cursor, keys)) + } } diff --git a/test_herodb.sh b/test_herodb.sh index fb96a6c..b3552d4 100755 --- a/test_herodb.sh +++ b/test_herodb.sh @@ -230,6 +230,58 @@ test_expiration() { redis_cmd "GET expire_ex_key" "" # Should be expired } +# Function to test SCAN operations +test_scan_operations() { + print_status "=== Testing SCAN Operations ===" + + # Set up test data for scanning + redis_cmd "SET scan_test1 value1" "OK" + redis_cmd "SET scan_test2 value2" "OK" + redis_cmd "SET scan_test3 value3" "OK" + redis_cmd "SET other_key other_value" "OK" + redis_cmd "HSET scan_hash field1 value1" "1" + + # Test basic SCAN + print_status "Testing basic SCAN with cursor 0" + redis_cmd "SCAN 0" "" + + # Test SCAN with MATCH pattern + print_status "Testing SCAN with MATCH pattern" + redis_cmd "SCAN 0 MATCH scan_test*" "" + + # Test SCAN with COUNT + print_status "Testing SCAN with COUNT 2" + redis_cmd "SCAN 0 COUNT 2" "" + + # Test SCAN with both MATCH and COUNT + print_status "Testing SCAN with MATCH and COUNT" + redis_cmd "SCAN 0 MATCH scan_* COUNT 1" "" + + # Test SCAN continuation with more keys + print_status "Setting up more keys for continuation test" + redis_cmd "SET scan_key1 val1" "OK" + redis_cmd "SET scan_key2 val2" "OK" + redis_cmd "SET scan_key3 val3" "OK" + redis_cmd "SET scan_key4 val4" "OK" + redis_cmd "SET scan_key5 val5" "OK" + + print_status "Testing SCAN with small COUNT for pagination" + redis_cmd "SCAN 0 COUNT 3" "" + + # Clean up SCAN test data + print_status "Cleaning up SCAN test data" + redis_cmd "DEL scan_test1" "1" + redis_cmd "DEL scan_test2" "1" + redis_cmd "DEL scan_test3" "1" + redis_cmd "DEL other_key" "1" + redis_cmd "DEL scan_hash" "1" + redis_cmd "DEL scan_key1" "1" + redis_cmd "DEL scan_key2" "1" + redis_cmd "DEL scan_key3" "1" + redis_cmd "DEL scan_key4" "1" + redis_cmd "DEL scan_key5" "1" +} + # Main execution main() { print_status "Starting HeroDB comprehensive test suite..." @@ -265,6 +317,7 @@ main() { test_keys_operations || failed_tests=$((failed_tests + 1)) test_info_operations || failed_tests=$((failed_tests + 1)) test_expiration || failed_tests=$((failed_tests + 1)) + test_scan_operations || failed_tests=$((failed_tests + 1)) # Summary echo