...
This commit is contained in:
		
							
								
								
									
										53
									
								
								src/cmd.rs
									
									
									
									
									
								
							
							
						
						
									
										53
									
								
								src/cmd.rs
									
									
									
									
									
								
							| @@ -28,6 +28,7 @@ pub enum Cmd { | ||||
|     HLen(String), | ||||
|     HMGet(String, Vec<String>), | ||||
|     HSetNx(String, String, String), | ||||
|     Scan(u64, Option<String>, Option<u64>), // cursor, pattern, count | ||||
|     Unknow, | ||||
| } | ||||
|  | ||||
| @@ -176,6 +177,43 @@ impl Cmd { | ||||
|                             } | ||||
|                             Cmd::HSetNx(cmd[1].clone(), cmd[2].clone(), cmd[3].clone()) | ||||
|                         } | ||||
|                         "scan" => { | ||||
|                             if cmd.len() < 2 { | ||||
|                                 return Err(DBError(format!("wrong number of arguments for SCAN command"))); | ||||
|                             } | ||||
|                              | ||||
|                             let cursor = cmd[1].parse::<u64>().map_err(|_| | ||||
|                                 DBError("ERR invalid cursor".to_string()))?; | ||||
|                              | ||||
|                             let mut pattern = None; | ||||
|                             let mut count = None; | ||||
|                             let mut i = 2; | ||||
|                              | ||||
|                             while i < cmd.len() { | ||||
|                                 match cmd[i].to_lowercase().as_str() { | ||||
|                                     "match" => { | ||||
|                                         if i + 1 >= cmd.len() { | ||||
|                                             return Err(DBError("ERR syntax error".to_string())); | ||||
|                                         } | ||||
|                                         pattern = Some(cmd[i + 1].clone()); | ||||
|                                         i += 2; | ||||
|                                     } | ||||
|                                     "count" => { | ||||
|                                         if i + 1 >= cmd.len() { | ||||
|                                             return Err(DBError("ERR syntax error".to_string())); | ||||
|                                         } | ||||
|                                         count = Some(cmd[i + 1].parse::<u64>().map_err(|_| | ||||
|                                             DBError("ERR value is not an integer or out of range".to_string()))?); | ||||
|                                         i += 2; | ||||
|                                     } | ||||
|                                     _ => { | ||||
|                                         return Err(DBError(format!("ERR syntax error"))); | ||||
|                                     } | ||||
|                                 } | ||||
|                             } | ||||
|                              | ||||
|                             Cmd::Scan(cursor, pattern, count) | ||||
|                         } | ||||
|                         _ => Cmd::Unknow, | ||||
|                     }, | ||||
|                     protocol.0, | ||||
| @@ -244,6 +282,7 @@ impl Cmd { | ||||
|             Cmd::HLen(key) => hlen_cmd(server, key).await, | ||||
|             Cmd::HMGet(key, fields) => hmget_cmd(server, key, fields).await, | ||||
|             Cmd::HSetNx(key, field, value) => hsetnx_cmd(server, key, field, value).await, | ||||
|             Cmd::Scan(cursor, pattern, count) => scan_cmd(server, cursor, pattern.as_deref(), count).await, | ||||
|             Cmd::Unknow => Ok(Protocol::err("unknown cmd")), | ||||
|         } | ||||
|     } | ||||
| @@ -444,3 +483,17 @@ async fn hsetnx_cmd(server: &Server, key: &str, field: &str, value: &str) -> Res | ||||
|         Err(e) => Ok(Protocol::err(&e.0)), | ||||
|     } | ||||
| } | ||||
|  | ||||
| async fn scan_cmd(server: &Server, cursor: &u64, pattern: Option<&str>, count: &Option<u64>) -> Result<Protocol, DBError> { | ||||
|     match server.storage.scan(*cursor, pattern, *count) { | ||||
|         Ok((next_cursor, keys)) => { | ||||
|             let mut result = Vec::new(); | ||||
|             result.push(Protocol::BulkString(next_cursor.to_string())); | ||||
|             result.push(Protocol::Array( | ||||
|                 keys.into_iter().map(Protocol::BulkString).collect(), | ||||
|             )); | ||||
|             Ok(Protocol::Array(result)) | ||||
|         } | ||||
|         Err(e) => Ok(Protocol::err(&e.0)), | ||||
|     } | ||||
| } | ||||
|   | ||||
| @@ -445,4 +445,65 @@ impl Storage { | ||||
|         write_txn.commit()?; | ||||
|         Ok(result) | ||||
|     } | ||||
|  | ||||
|     pub fn scan(&self, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<String>), DBError> { | ||||
|         let read_txn = self.db.begin_read()?; | ||||
|         let table = read_txn.open_table(TYPES_TABLE)?; | ||||
|          | ||||
|         let count = count.unwrap_or(10); // Default count is 10 | ||||
|         let mut keys = Vec::new(); | ||||
|         let mut current_cursor = 0u64; | ||||
|         let mut returned_keys = 0u64; | ||||
|          | ||||
|         let mut iter = table.iter()?; | ||||
|         while let Some(entry) = iter.next() { | ||||
|             let key = entry?.0.value().to_string(); | ||||
|              | ||||
|             // Skip keys until we reach the cursor position | ||||
|             if current_cursor < cursor { | ||||
|                 current_cursor += 1; | ||||
|                 continue; | ||||
|             } | ||||
|              | ||||
|             // Check if key matches pattern | ||||
|             let matches = match pattern { | ||||
|                 Some(pat) => { | ||||
|                     if pat == "*" { | ||||
|                         true | ||||
|                     } else if pat.contains('*') { | ||||
|                         // Simple glob pattern matching | ||||
|                         let pattern_parts: Vec<&str> = pat.split('*').collect(); | ||||
|                         if pattern_parts.len() == 2 { | ||||
|                             let prefix = pattern_parts[0]; | ||||
|                             let suffix = pattern_parts[1]; | ||||
|                             key.starts_with(prefix) && key.ends_with(suffix) | ||||
|                         } else { | ||||
|                             key.contains(&pat.replace('*', "")) | ||||
|                         } | ||||
|                     } else { | ||||
|                         key.contains(pat) | ||||
|                     } | ||||
|                 } | ||||
|                 None => true, | ||||
|             }; | ||||
|              | ||||
|             if matches { | ||||
|                 keys.push(key); | ||||
|                 returned_keys += 1; | ||||
|                  | ||||
|                 // Stop if we've returned enough keys | ||||
|                 if returned_keys >= count { | ||||
|                     current_cursor += 1; | ||||
|                     break; | ||||
|                 } | ||||
|             } | ||||
|              | ||||
|             current_cursor += 1; | ||||
|         } | ||||
|          | ||||
|         // If we've reached the end of iteration, return cursor 0 to indicate completion | ||||
|         let next_cursor = if returned_keys < count { 0 } else { current_cursor }; | ||||
|          | ||||
|         Ok((next_cursor, keys)) | ||||
|     } | ||||
| } | ||||
|   | ||||
| @@ -230,6 +230,58 @@ test_expiration() { | ||||
|     redis_cmd "GET expire_ex_key" ""        # Should be expired | ||||
| } | ||||
|  | ||||
| # Function to test SCAN operations | ||||
| test_scan_operations() { | ||||
|     print_status "=== Testing SCAN Operations ===" | ||||
|      | ||||
|     # Set up test data for scanning | ||||
|     redis_cmd "SET scan_test1 value1" "OK" | ||||
|     redis_cmd "SET scan_test2 value2" "OK" | ||||
|     redis_cmd "SET scan_test3 value3" "OK" | ||||
|     redis_cmd "SET other_key other_value" "OK" | ||||
|     redis_cmd "HSET scan_hash field1 value1" "1" | ||||
|      | ||||
|     # Test basic SCAN | ||||
|     print_status "Testing basic SCAN with cursor 0" | ||||
|     redis_cmd "SCAN 0" "" | ||||
|      | ||||
|     # Test SCAN with MATCH pattern | ||||
|     print_status "Testing SCAN with MATCH pattern" | ||||
|     redis_cmd "SCAN 0 MATCH scan_test*" "" | ||||
|      | ||||
|     # Test SCAN with COUNT | ||||
|     print_status "Testing SCAN with COUNT 2" | ||||
|     redis_cmd "SCAN 0 COUNT 2" "" | ||||
|      | ||||
|     # Test SCAN with both MATCH and COUNT | ||||
|     print_status "Testing SCAN with MATCH and COUNT" | ||||
|     redis_cmd "SCAN 0 MATCH scan_* COUNT 1" "" | ||||
|      | ||||
|     # Test SCAN continuation with more keys | ||||
|     print_status "Setting up more keys for continuation test" | ||||
|     redis_cmd "SET scan_key1 val1" "OK" | ||||
|     redis_cmd "SET scan_key2 val2" "OK" | ||||
|     redis_cmd "SET scan_key3 val3" "OK" | ||||
|     redis_cmd "SET scan_key4 val4" "OK" | ||||
|     redis_cmd "SET scan_key5 val5" "OK" | ||||
|      | ||||
|     print_status "Testing SCAN with small COUNT for pagination" | ||||
|     redis_cmd "SCAN 0 COUNT 3" "" | ||||
|      | ||||
|     # Clean up SCAN test data | ||||
|     print_status "Cleaning up SCAN test data" | ||||
|     redis_cmd "DEL scan_test1" "1" | ||||
|     redis_cmd "DEL scan_test2" "1" | ||||
|     redis_cmd "DEL scan_test3" "1" | ||||
|     redis_cmd "DEL other_key" "1" | ||||
|     redis_cmd "DEL scan_hash" "1" | ||||
|     redis_cmd "DEL scan_key1" "1" | ||||
|     redis_cmd "DEL scan_key2" "1" | ||||
|     redis_cmd "DEL scan_key3" "1" | ||||
|     redis_cmd "DEL scan_key4" "1" | ||||
|     redis_cmd "DEL scan_key5" "1" | ||||
| } | ||||
|  | ||||
| # Main execution | ||||
| main() { | ||||
|     print_status "Starting HeroDB comprehensive test suite..." | ||||
| @@ -265,6 +317,7 @@ main() { | ||||
|     test_keys_operations || failed_tests=$((failed_tests + 1)) | ||||
|     test_info_operations || failed_tests=$((failed_tests + 1)) | ||||
|     test_expiration || failed_tests=$((failed_tests + 1)) | ||||
|     test_scan_operations || failed_tests=$((failed_tests + 1)) | ||||
|      | ||||
|     # Summary | ||||
|     echo | ||||
|   | ||||
		Reference in New Issue
	
	Block a user