it works
This commit is contained in:
		
							
								
								
									
										36
									
								
								src/cmd.rs
									
									
									
									
									
								
							
							
						
						
									
										36
									
								
								src/cmd.rs
									
									
									
									
									
								
							| @@ -484,10 +484,7 @@ impl Cmd { | ||||
|             Cmd::LIndex(key, index) => lindex_cmd(server, &key, index).await, | ||||
|             Cmd::LRange(key, start, stop) => lrange_cmd(server, &key, start, stop).await, | ||||
|             Cmd::FlushDb => flushdb_cmd(server).await, | ||||
|             Cmd::Unknow(s) => { | ||||
|                 println!("\x1b[31;1munknown command: {}\x1b[0m", s); | ||||
|                 Ok(Protocol::err(&format!("ERR unknown command '{}'", s))) | ||||
|             } | ||||
|             Cmd::Unknow(s) => Ok(Protocol::err(&format!("ERR unknown command `{}`", s))), | ||||
|         } | ||||
|     } | ||||
|      | ||||
| @@ -645,24 +642,21 @@ async fn incr_cmd(server: &Server, key: &String) -> Result<Protocol, DBError> { | ||||
| } | ||||
|  | ||||
| fn config_get_cmd(name: &String, server: &Server) -> Result<Protocol, DBError> { | ||||
|     let mut result = Vec::new(); | ||||
|     result.push(Protocol::BulkString(name.clone())); | ||||
|     let value = match name.as_str() { | ||||
|         "dir" => Some(server.option.dir.clone()), | ||||
|         "dbfilename" => Some(format!("{}.db", server.selected_db)), | ||||
|         "databases" => Some("16".to_string()), // Hardcoded as per original logic | ||||
|         _ => None, | ||||
|     }; | ||||
|  | ||||
|     match name.as_str() { | ||||
|         "dir" => { | ||||
|             result.push(Protocol::BulkString(server.option.dir.clone())); | ||||
|             Ok(Protocol::Array(result)) | ||||
|         } | ||||
|         "dbfilename" => { | ||||
|             result.push(Protocol::BulkString(format!("{}.db", server.selected_db))); | ||||
|             Ok(Protocol::Array(result)) | ||||
|         }, | ||||
|         "databases" => { | ||||
|             // This is hardcoded, as the feature was removed | ||||
|             result.push(Protocol::BulkString("16".to_string())); | ||||
|             Ok(Protocol::Array(result)) | ||||
|         }, | ||||
|         _ => Ok(Protocol::Array(vec![])), | ||||
|     if let Some(val) = value { | ||||
|         Ok(Protocol::Array(vec![ | ||||
|             Protocol::BulkString(name.clone()), | ||||
|             Protocol::BulkString(val), | ||||
|         ])) | ||||
|     } else { | ||||
|         // Return an empty array for unknown config options, which is standard Redis behavior | ||||
|         Ok(Protocol::Array(vec![])) | ||||
|     } | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -42,6 +42,13 @@ impl Server { | ||||
|         let db_file_path = std::path::PathBuf::from(self.option.dir.clone()) | ||||
|             .join(format!("{}.db", self.selected_db)); | ||||
|          | ||||
|         // Ensure the directory exists before creating the database file | ||||
|         if let Some(parent_dir) = db_file_path.parent() { | ||||
|             std::fs::create_dir_all(parent_dir).map_err(|e| { | ||||
|                 DBError(format!("Failed to create directory {}: {}", parent_dir.display(), e)) | ||||
|             })?; | ||||
|         } | ||||
|          | ||||
|         println!("Creating new db file: {}", db_file_path.display()); | ||||
|          | ||||
|         let storage = Arc::new(Storage::new( | ||||
|   | ||||
							
								
								
									
										193
									
								
								src/storage.rs
									
									
									
									
									
								
							
							
						
						
									
										193
									
								
								src/storage.rs
									
									
									
									
									
								
							| @@ -225,12 +225,15 @@ impl Storage { | ||||
|             for key in keys { | ||||
|                 strings_table.remove(key.as_str())?; | ||||
|             } | ||||
|             let keys: Vec<(String,String)> = hashes_table.iter()?.map(|item| { | ||||
|                 let binding = item.unwrap(); | ||||
|                 let (key, field) = binding.0.value(); | ||||
|                 (key.to_string(), field.to_string()) | ||||
|             }).collect(); | ||||
|             for (key,field) in keys { | ||||
|             let keys: Vec<(String, String)> = hashes_table | ||||
|                 .iter()? | ||||
|                 .map(|item| { | ||||
|                     let binding = item.unwrap(); | ||||
|                     let (k, f) = binding.0.value(); | ||||
|                     (k.to_string(), f.to_string()) | ||||
|                 }) | ||||
|                 .collect(); | ||||
|             for (key, field) in keys { | ||||
|                 hashes_table.remove((key.as_str(), field.as_str()))?; | ||||
|             } | ||||
|             let keys: Vec<String> = lists_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect(); | ||||
| @@ -262,9 +265,20 @@ impl Storage { | ||||
|         let read_txn = self.db.begin_read()?; | ||||
|         let table = read_txn.open_table(TYPES_TABLE)?; | ||||
|          | ||||
|         match table.get(key)? { | ||||
|             Some(type_val) => Ok(Some(type_val.value().to_string())), | ||||
|             None => Ok(None), | ||||
|         // Before returning type, check for expiration | ||||
|         if let Some(type_val) = table.get(key)? { | ||||
|             if type_val.value() == "string" { | ||||
|                 let expiration_table = read_txn.open_table(EXPIRATION_TABLE)?; | ||||
|                 if let Some(expires_at) = expiration_table.get(key)? { | ||||
|                     if now_in_millis() > expires_at.value() as u128 { | ||||
|                         // The key is expired, so it effectively has no type | ||||
|                         return Ok(None); | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|             Ok(Some(type_val.value().to_string())) | ||||
|         } else { | ||||
|             Ok(None) | ||||
|         } | ||||
|     } | ||||
|  | ||||
| @@ -353,7 +367,7 @@ impl Storage { | ||||
|         { | ||||
|             let mut types_table = write_txn.open_table(TYPES_TABLE)?; | ||||
|             let mut strings_table = write_txn.open_table(STRINGS_TABLE)?; | ||||
|             let mut hashes_table = write_txn.open_table(HASHES_TABLE)?; | ||||
|             let mut hashes_table: redb::Table<(&str, &str), &[u8]> = write_txn.open_table(HASHES_TABLE)?; | ||||
|             let mut lists_table = write_txn.open_table(LISTS_TABLE)?; | ||||
|              | ||||
|             // Remove from type table | ||||
| @@ -501,37 +515,32 @@ impl Storage { | ||||
|     } | ||||
|  | ||||
|     pub fn hdel(&self, key: &str, fields: &[String]) -> Result<u64, DBError> { | ||||
|         let write_txn = self.db.begin_write()?; | ||||
|         let mut deleted = 0u64; | ||||
|          | ||||
|         { | ||||
|             let types_table = write_txn.open_table(TYPES_TABLE)?; | ||||
|             let key_type = types_table.get(key)?; | ||||
|             match key_type { | ||||
|                 Some(type_val) if type_val.value() == "hash" => { | ||||
|         // Enforce type check before proceeding to write transaction | ||||
|         let key_type = self.get_key_type(key)?; | ||||
|         match key_type.as_deref() { | ||||
|             Some("hash") => { | ||||
|                 let write_txn = self.db.begin_write()?; | ||||
|                 let mut deleted = 0u64; | ||||
|                 { | ||||
|                     let mut hashes_table = write_txn.open_table(HASHES_TABLE)?; | ||||
|                      | ||||
|                     for field in fields { | ||||
|                         if hashes_table.remove((key, field.as_str()))?.is_some() { | ||||
|                             deleted += 1; | ||||
|                         } | ||||
|                     } | ||||
|                 } | ||||
|                 Some(_) => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), | ||||
|                 None => {} | ||||
|                 write_txn.commit()?; | ||||
|                 Ok(deleted) | ||||
|             } | ||||
|             Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), | ||||
|             None => Ok(0), // Key doesn't exist, so 0 fields deleted. | ||||
|         } | ||||
|          | ||||
|         write_txn.commit()?; | ||||
|         Ok(deleted) | ||||
|     } | ||||
|  | ||||
|     pub fn hexists(&self, key: &str, field: &str) -> Result<bool, DBError> { | ||||
|         let read_txn = self.db.begin_read()?; | ||||
|          | ||||
|         let types_table = read_txn.open_table(TYPES_TABLE)?; | ||||
|         match types_table.get(key)? { | ||||
|             Some(type_val) if type_val.value() == "hash" => { | ||||
|         match self.get_key_type(key)?.as_deref() { | ||||
|             Some("hash") => { | ||||
|                 let read_txn = self.db.begin_read()?; | ||||
|                 let hashes_table = read_txn.open_table(HASHES_TABLE)?; | ||||
|                 Ok(hashes_table.get((key, field))?.is_some()) | ||||
|             } | ||||
| @@ -541,23 +550,14 @@ impl Storage { | ||||
|     } | ||||
|  | ||||
|     pub fn hkeys(&self, key: &str) -> Result<Vec<String>, DBError> { | ||||
|         let read_txn = self.db.begin_read()?; | ||||
|          | ||||
|         let types_table = read_txn.open_table(TYPES_TABLE)?; | ||||
|         match types_table.get(key)? { | ||||
|             Some(type_val) if type_val.value() == "hash" => { | ||||
|         match self.get_key_type(key)?.as_deref() { | ||||
|             Some("hash") => { | ||||
|                 let read_txn = self.db.begin_read()?; | ||||
|                 let hashes_table = read_txn.open_table(HASHES_TABLE)?; | ||||
|                 let mut result = Vec::new(); | ||||
|                  | ||||
|                 let mut iter = hashes_table.iter()?; | ||||
|                 while let Some(entry) = iter.next() { | ||||
|                     let entry = entry?; | ||||
|                     let (hash_key, field) = entry.0.value(); | ||||
|                     if hash_key == key { | ||||
|                         result.push(field.to_string()); | ||||
|                     } | ||||
|                 for entry in hashes_table.range((key, "")..=(key, "\u{FFFF}"))? { | ||||
|                     result.push(entry?.0.value().1.to_string()); | ||||
|                 } | ||||
|                  | ||||
|                 Ok(result) | ||||
|             } | ||||
|             Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), | ||||
| @@ -566,26 +566,15 @@ impl Storage { | ||||
|     } | ||||
|  | ||||
|     pub fn hvals(&self, key: &str) -> Result<Vec<String>, DBError> { | ||||
|         let read_txn = self.db.begin_read()?; | ||||
|          | ||||
|         let types_table = read_txn.open_table(TYPES_TABLE)?; | ||||
|         match types_table.get(key)? { | ||||
|             Some(type_val) if type_val.value() == "hash" => { | ||||
|         match self.get_key_type(key)?.as_deref() { | ||||
|             Some("hash") => { | ||||
|                 let read_txn = self.db.begin_read()?; | ||||
|                 let hashes_table = read_txn.open_table(HASHES_TABLE)?; | ||||
|                 let mut result = Vec::new(); | ||||
|                  | ||||
|                 let mut iter = hashes_table.iter()?; | ||||
|                 while let Some(entry) = iter.next() { | ||||
|                     let entry = entry?; | ||||
|                     let (hash_key, _) = entry.0.value(); | ||||
|                     let value = entry.1.value(); | ||||
|                     if hash_key == key { | ||||
|                         let decrypted = self.decrypt_if_needed(value)?; | ||||
|                         let value_str = String::from_utf8(decrypted)?; | ||||
|                         result.push(value_str); | ||||
|                     } | ||||
|                 for entry in hashes_table.range((key, "")..=(key, "\u{FFFF}"))? { | ||||
|                     let value = self.decrypt_if_needed(entry?.1.value())?; | ||||
|                     result.push(String::from_utf8(value)?); | ||||
|                 } | ||||
|                  | ||||
|                 Ok(result) | ||||
|             } | ||||
|             Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), | ||||
| @@ -594,24 +583,12 @@ impl Storage { | ||||
|     } | ||||
|  | ||||
|     pub fn hlen(&self, key: &str) -> Result<u64, DBError> { | ||||
|         let read_txn = self.db.begin_read()?; | ||||
|          | ||||
|         let types_table = read_txn.open_table(TYPES_TABLE)?; | ||||
|         match types_table.get(key)? { | ||||
|             Some(type_val) if type_val.value() == "hash" => { | ||||
|         match self.get_key_type(key)?.as_deref() { | ||||
|             Some("hash") => { | ||||
|                 let read_txn = self.db.begin_read()?; | ||||
|                 let hashes_table = read_txn.open_table(HASHES_TABLE)?; | ||||
|                 let mut count = 0u64; | ||||
|                  | ||||
|                 let mut iter = hashes_table.iter()?; | ||||
|                 while let Some(entry) = iter.next() { | ||||
|                     let entry = entry?; | ||||
|                     let (hash_key, _) = entry.0.value(); | ||||
|                     if hash_key == key { | ||||
|                         count += 1; | ||||
|                     } | ||||
|                 } | ||||
|                  | ||||
|                 Ok(count) | ||||
|                 // Use `range` for efficiency | ||||
|                 Ok(hashes_table.range((key, "")..=(key, "\u{FFFF}"))?.count() as u64) | ||||
|             } | ||||
|             Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), | ||||
|             None => Ok(0), | ||||
| @@ -619,29 +596,22 @@ impl Storage { | ||||
|     } | ||||
|  | ||||
|     pub fn hmget(&self, key: &str, fields: &[String]) -> Result<Vec<Option<String>>, DBError> { | ||||
|         let read_txn = self.db.begin_read()?; | ||||
|          | ||||
|         let types_table = read_txn.open_table(TYPES_TABLE)?; | ||||
|         match types_table.get(key)? { | ||||
|             Some(type_val) if type_val.value() == "hash" => { | ||||
|         match self.get_key_type(key)?.as_deref() { | ||||
|             Some("hash") => { | ||||
|                 let read_txn = self.db.begin_read()?; | ||||
|                 let hashes_table = read_txn.open_table(HASHES_TABLE)?; | ||||
|                 let mut result = Vec::new(); | ||||
|                  | ||||
|                 for field in fields { | ||||
|                     match hashes_table.get((key, field.as_str()))? { | ||||
|                         Some(data) => { | ||||
|                             let decrypted = self.decrypt_if_needed(data.value())?; | ||||
|                             let value = String::from_utf8(decrypted)?; | ||||
|                             result.push(Some(value)); | ||||
|                         } | ||||
|                         None => result.push(None), | ||||
|                     } | ||||
|                     let value = match hashes_table.get((key, field.as_str()))? { | ||||
|                         Some(data) => Some(String::from_utf8(self.decrypt_if_needed(data.value())?)?), | ||||
|                         None => None, | ||||
|                     }; | ||||
|                     result.push(value); | ||||
|                 } | ||||
|                  | ||||
|                 Ok(result) | ||||
|             } | ||||
|             Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), | ||||
|             None => Ok(fields.iter().map(|_| None).collect()), | ||||
|             None => Ok(vec![None; fields.len()]), | ||||
|         } | ||||
|     } | ||||
|  | ||||
| @@ -684,16 +654,19 @@ impl Storage { | ||||
|  | ||||
|     pub fn scan(&self, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<String>), DBError> { | ||||
|         let read_txn = self.db.begin_read()?; | ||||
|         let table = read_txn.open_table(TYPES_TABLE)?; | ||||
|          | ||||
|         // Explicitly specify the table type to avoid confusion | ||||
|         let types_table: redb::ReadOnlyTable<&str, &str> = read_txn.open_table(TYPES_TABLE)?; | ||||
|          | ||||
|         let count = count.unwrap_or(10); // Default count is 10 | ||||
|         let mut keys = Vec::new(); | ||||
|         let mut current_cursor = 0u64; | ||||
|         let mut returned_keys = 0u64; | ||||
|          | ||||
|         let mut iter = table.iter()?; | ||||
|         let mut iter = types_table.iter()?; | ||||
|         while let Some(entry) = iter.next() { | ||||
|             let key = entry?.0.value().to_string(); | ||||
|             let entry = entry?; | ||||
|             let key = entry.0.value().to_string(); | ||||
|              | ||||
|             // Skip keys until we reach the cursor position | ||||
|             if current_cursor < cursor { | ||||
| @@ -707,15 +680,8 @@ impl Storage { | ||||
|                     if pat == "*" { | ||||
|                         true | ||||
|                     } else if pat.contains('*') { | ||||
|                         // Simple glob pattern matching | ||||
|                         let pattern_parts: Vec<&str> = pat.split('*').collect(); | ||||
|                         if pattern_parts.len() == 2 { | ||||
|                             let prefix = pattern_parts[0]; | ||||
|                             let suffix = pattern_parts[1]; | ||||
|                             key.starts_with(prefix) && key.ends_with(suffix) | ||||
|                         } else { | ||||
|                             key.contains(&pat.replace('*', "")) | ||||
|                         } | ||||
|                         // Use the glob_match function for better pattern matching | ||||
|                         glob_match(pat, &key) | ||||
|                     } else { | ||||
|                         key.contains(pat) | ||||
|                     } | ||||
| @@ -746,10 +712,10 @@ impl Storage { | ||||
|         let read_txn = self.db.begin_read()?; | ||||
|          | ||||
|         // Check if key exists and is a hash | ||||
|         let types_table = read_txn.open_table(TYPES_TABLE)?; | ||||
|         let types_table: redb::ReadOnlyTable<&str, &str> = read_txn.open_table(TYPES_TABLE)?; | ||||
|         match types_table.get(key)? { | ||||
|             Some(type_val) if type_val.value() == "hash" => { | ||||
|                 let hashes_table = read_txn.open_table(HASHES_TABLE)?; | ||||
|                 let hashes_table: redb::ReadOnlyTable<(&str, &str), &[u8]> = read_txn.open_table(HASHES_TABLE)?; | ||||
|                 let count = count.unwrap_or(10); | ||||
|                 let mut fields = Vec::new(); | ||||
|                 let mut current_cursor = 0u64; | ||||
| @@ -777,14 +743,8 @@ impl Storage { | ||||
|                             if pat == "*" { | ||||
|                                 true | ||||
|                             } else if pat.contains('*') { | ||||
|                                 let pattern_parts: Vec<&str> = pat.split('*').collect(); | ||||
|                                 if pattern_parts.len() == 2 { | ||||
|                                     let prefix = pattern_parts[0]; | ||||
|                                     let suffix = pattern_parts[1]; | ||||
|                                     field.starts_with(prefix) && field.ends_with(suffix) | ||||
|                                 } else { | ||||
|                                     field.contains(&pat.replace('*', "")) | ||||
|                                 } | ||||
|                                 // Use the glob_match function for better pattern matching | ||||
|                                 glob_match(pat, field) | ||||
|                             } else { | ||||
|                                 field.contains(pat) | ||||
|                             } | ||||
| @@ -807,7 +767,8 @@ impl Storage { | ||||
|                     current_cursor += 1; | ||||
|                 } | ||||
|                  | ||||
|                 let next_cursor = if iter.next().is_none() { 0 } else { current_cursor }; | ||||
|                 // Check if there are more entries by trying to get the next one | ||||
|                 let next_cursor = if returned_fields < count { 0 } else { current_cursor }; | ||||
|                 Ok((next_cursor, fields)) | ||||
|             } | ||||
|             Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), | ||||
|   | ||||
		Reference in New Issue
	
	Block a user