Files
horus/lib/osiris/core/retrieve/search.rs
2025-11-13 20:44:00 +01:00

151 lines
4.8 KiB
Rust

use crate::error::Result;
use crate::index::FieldIndex;
use crate::retrieve::query::{RetrievalQuery, SearchResult};
use crate::store::{HeroDbClient, OsirisObject};
/// Search engine for OSIRIS.
///
/// Combines field-index filtering ([`FieldIndex`]) with a simple,
/// case-insensitive substring scorer over objects fetched from HeroDB.
pub struct SearchEngine {
    /// Client used to fetch full objects for text scoring and snippets.
    client: HeroDbClient,
    /// Field index used to resolve candidate IDs (built from a clone of `client`).
    index: FieldIndex,
}

impl SearchEngine {
    /// Create a new search engine.
    ///
    /// The field index receives its own clone of `client`.
    pub fn new(client: HeroDbClient) -> Self {
        let index = FieldIndex::new(client.clone());
        Self { client, index }
    }

    /// Execute a search query.
    ///
    /// 1. Candidate IDs come from the field index: every indexed ID when
    ///    `query.filters` is empty, otherwise the IDs matching the filters.
    /// 2. With a text query, each candidate object is fetched and scored by
    ///    [`Self::compute_text_score`]. Only positive scores are kept, each with a
    ///    snippet. Without a text query every candidate gets score `1.0`.
    /// 3. Results are sorted by descending score (stable, so ties keep candidate
    ///    order) and truncated to `query.top_k`.
    ///
    /// Candidates whose object cannot be fetched are skipped rather than failing
    /// the whole search.
    ///
    /// # Errors
    /// Returns any error produced while reading candidate IDs from the field index.
    pub async fn search(&self, query: &RetrievalQuery) -> Result<Vec<SearchResult>> {
        // Step 1: candidate IDs from field filters.
        let candidate_ids = if query.filters.is_empty() {
            self.index.get_all_ids().await?
        } else {
            self.index.get_ids_by_filters(&query.filters).await?
        };

        // Step 2: optional case-insensitive substring filtering/scoring.
        let mut results = Vec::new();
        if let Some(text_query) = &query.text {
            let text_query_lower = text_query.to_lowercase();
            for id in candidate_ids {
                // Best-effort: an object that fails to load is skipped, not fatal.
                if let Ok(obj) = self.client.get_object(&id).await {
                    let score = self.compute_text_score(&obj, &text_query_lower);
                    if score > 0.0 {
                        let snippet = self.extract_snippet(&obj, &text_query_lower);
                        results.push(SearchResult::new(id, score).with_snippet(snippet));
                    }
                }
            }
        } else {
            results.extend(candidate_ids.into_iter().map(|id| SearchResult::new(id, 1.0)));
        }

        // Step 3: sort by score (descending) and limit. `total_cmp` gives a total
        // order, so the comparator can never panic (unlike `partial_cmp().unwrap()`
        // on a NaN).
        results.sort_by(|a, b| b.score.total_cmp(&a.score));
        results.truncate(query.top_k);
        Ok(results)
    }

    /// Compute a simple substring-match score, clamped to at most `1.0`.
    ///
    /// `query` is expected to be lowercased already. Each field is lowercased
    /// before comparison. Contributions before clamping:
    /// - title contains `query`: `+0.5`
    /// - text contains `query`: `+0.5`, plus `0.1` per additional occurrence
    /// - each tag whose key or value contains `query`: `+0.2`
    fn compute_text_score(&self, obj: &OsirisObject, query: &str) -> f32 {
        let mut score: f32 = 0.0;

        if let Some(title) = &obj.meta.title {
            if title.to_lowercase().contains(query) {
                score += 0.5;
            }
        }

        if let Some(text) = &obj.text {
            // Lowercase once. A non-zero match count is exactly the `contains` check.
            let count = text.to_lowercase().matches(query).count();
            if count > 0 {
                score += 0.5;
                // Bonus for multiple occurrences.
                score += (count as f32 - 1.0) * 0.1;
            }
        }

        for (key, value) in &obj.meta.tags {
            if key.to_lowercase().contains(query) || value.to_lowercase().contains(query) {
                score += 0.2;
            }
        }

        score.min(1.0)
    }

    /// Extract a snippet around the first case-insensitive match of `query`
    /// (already lowercased) in the object's text.
    ///
    /// The window extends roughly `SNIPPET_LENGTH / 2` bytes on each side of the
    /// match. Its edges are widened to the nearest UTF-8 character boundaries so
    /// slicing never panics on multi-byte text. A truncated end is marked with `...`.
    ///
    /// Falls back to the title, then to the first ~`SNIPPET_LENGTH` bytes of the
    /// text, then to `"[No content]"`.
    fn extract_snippet(&self, obj: &OsirisObject, query: &str) -> String {
        const SNIPPET_LENGTH: usize = 100;

        // Try to find the snippet in the text.
        if let Some(text) = &obj.text {
            if let Some((match_start, match_end)) = Self::find_case_insensitive(text, query) {
                let start =
                    Self::floor_char_boundary(text, match_start.saturating_sub(SNIPPET_LENGTH / 2));
                let end = Self::ceil_char_boundary(text, match_end + SNIPPET_LENGTH / 2);

                let mut snippet = String::with_capacity(end - start + 6);
                if start > 0 {
                    snippet.push_str("...");
                }
                snippet.push_str(&text[start..end]);
                if end < text.len() {
                    snippet.push_str("...");
                }
                return snippet;
            }
        }

        // Fall back to the title, or to the first N bytes of the text.
        if let Some(title) = &obj.meta.title {
            return title.clone();
        }

        if let Some(text) = &obj.text {
            let end = Self::floor_char_boundary(text, SNIPPET_LENGTH);
            let mut snippet = text[..end].to_string();
            if end < text.len() {
                snippet.push_str("...");
            }
            return snippet;
        }

        String::from("[No content]")
    }

    /// Find the first case-insensitive occurrence of `query_lower` in `text`.
    ///
    /// Returns the byte range of the match in `text` itself, not in a lowercased
    /// copy. Lowercasing can change a character's encoded length (e.g. `'İ'` →
    /// `"i̇"`), so offsets in a lowercased copy are not valid indices into `text`.
    ///
    /// NOTE: lowercasing here is per character (`char::to_lowercase`). It differs
    /// from `str::to_lowercase` only for word-final `'Σ'` (`'σ'` vs `'ς'`). A query
    /// relying on that form may not be located, and the caller then falls back.
    fn find_case_insensitive(text: &str, query_lower: &str) -> Option<(usize, usize)> {
        // `origin[i]` is the byte offset in `text` of the char that produced byte
        // `i` of `lower`.
        let mut lower = String::with_capacity(text.len());
        let mut origin = Vec::with_capacity(text.len());
        for (offset, ch) in text.char_indices() {
            lower.extend(ch.to_lowercase());
            origin.resize(lower.len(), offset);
        }

        let pos = lower.find(query_lower)?;
        // An index at `lower.len()` (a match ending at the very end) maps to `text.len()`.
        let to_original = |i: usize| origin.get(i).copied().unwrap_or(text.len());
        Some((to_original(pos), to_original(pos + query_lower.len())))
    }

    /// Largest char boundary of `text` that is `<= index`, with `index` clamped to `text.len()`.
    fn floor_char_boundary(text: &str, index: usize) -> usize {
        let mut i = index.min(text.len());
        // Terminates: 0 is always a char boundary.
        while !text.is_char_boundary(i) {
            i -= 1;
        }
        i
    }

    /// Smallest char boundary of `text` that is `>= index`, with `index` clamped to `text.len()`.
    fn ceil_char_boundary(text: &str, index: usize) -> usize {
        let mut i = index.min(text.len());
        // Terminates: `text.len()` is always a char boundary.
        while !text.is_char_boundary(i) {
            i += 1;
        }
        i
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Live integration check against HeroDB at `redis://localhost:6379`
    /// (database 1). A text search for "rust" must return at most `top_k` hits.
    /// Ignored by default because it needs that server to be running.
    #[tokio::test]
    #[ignore]
    async fn test_search() {
        let engine = SearchEngine::new(HeroDbClient::new("redis://localhost:6379", 1).unwrap());

        let query = RetrievalQuery::new(String::from("test"))
            .with_text(String::from("rust"))
            .with_top_k(10);

        let hits = engine.search(&query).await.unwrap();
        assert!(hits.len() <= 10);
    }
}