151 lines
4.8 KiB
Rust
151 lines
4.8 KiB
Rust
use crate::error::Result;
|
|
use crate::index::FieldIndex;
|
|
use crate::retrieve::query::{RetrievalQuery, SearchResult};
|
|
use crate::store::{HeroDbClient, OsirisObject};
|
|
|
|
/// Search engine for OSIRIS
|
|
pub struct SearchEngine {
|
|
client: HeroDbClient,
|
|
index: FieldIndex,
|
|
}
|
|
|
|
impl SearchEngine {
|
|
/// Create a new search engine
|
|
pub fn new(client: HeroDbClient) -> Self {
|
|
let index = FieldIndex::new(client.clone());
|
|
Self { client, index }
|
|
}
|
|
|
|
/// Execute a search query
|
|
pub async fn search(&self, query: &RetrievalQuery) -> Result<Vec<SearchResult>> {
|
|
// Step 1: Get candidate IDs from field filters
|
|
let candidate_ids = if query.filters.is_empty() {
|
|
self.index.get_all_ids().await?
|
|
} else {
|
|
self.index.get_ids_by_filters(&query.filters).await?
|
|
};
|
|
|
|
// Step 2: If text query is provided, filter by substring match
|
|
let mut results = Vec::new();
|
|
|
|
if let Some(text_query) = &query.text {
|
|
let text_query_lower = text_query.to_lowercase();
|
|
|
|
for id in candidate_ids {
|
|
// Fetch the object
|
|
if let Ok(obj) = self.client.get_object(&id).await {
|
|
// Check if text matches
|
|
let score = self.compute_text_score(&obj, &text_query_lower);
|
|
|
|
if score > 0.0 {
|
|
let snippet = self.extract_snippet(&obj, &text_query_lower);
|
|
results.push(SearchResult::new(id, score).with_snippet(snippet));
|
|
}
|
|
}
|
|
}
|
|
} else {
|
|
// No text query, return all candidates with score 1.0
|
|
for id in candidate_ids {
|
|
results.push(SearchResult::new(id, 1.0));
|
|
}
|
|
}
|
|
|
|
// Step 3: Sort by score (descending) and limit
|
|
results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
|
|
results.truncate(query.top_k);
|
|
|
|
Ok(results)
|
|
}
|
|
|
|
/// Compute text match score (simple substring matching)
|
|
fn compute_text_score(&self, obj: &OsirisObject, query: &str) -> f32 {
|
|
let mut score = 0.0;
|
|
|
|
// Check title
|
|
if let Some(title) = &obj.meta.title {
|
|
if title.to_lowercase().contains(query) {
|
|
score += 0.5;
|
|
}
|
|
}
|
|
|
|
// Check text content
|
|
if let Some(text) = &obj.text {
|
|
if text.to_lowercase().contains(query) {
|
|
score += 0.5;
|
|
|
|
// Bonus for multiple occurrences
|
|
let count = text.to_lowercase().matches(query).count();
|
|
score += (count as f32 - 1.0) * 0.1;
|
|
}
|
|
}
|
|
|
|
// Check tags
|
|
for (key, value) in &obj.meta.tags {
|
|
if key.to_lowercase().contains(query) || value.to_lowercase().contains(query) {
|
|
score += 0.2;
|
|
}
|
|
}
|
|
|
|
score.min(1.0)
|
|
}
|
|
|
|
/// Extract a snippet around the matched text
|
|
fn extract_snippet(&self, obj: &OsirisObject, query: &str) -> String {
|
|
const SNIPPET_LENGTH: usize = 100;
|
|
|
|
// Try to find snippet in text
|
|
if let Some(text) = &obj.text {
|
|
let text_lower = text.to_lowercase();
|
|
if let Some(pos) = text_lower.find(query) {
|
|
let start = pos.saturating_sub(SNIPPET_LENGTH / 2);
|
|
let end = (pos + query.len() + SNIPPET_LENGTH / 2).min(text.len());
|
|
|
|
let mut snippet = text[start..end].to_string();
|
|
if start > 0 {
|
|
snippet = format!("...{}", snippet);
|
|
}
|
|
if end < text.len() {
|
|
snippet = format!("{}...", snippet);
|
|
}
|
|
|
|
return snippet;
|
|
}
|
|
}
|
|
|
|
// Fallback to title or first N chars
|
|
if let Some(title) = &obj.meta.title {
|
|
return title.clone();
|
|
}
|
|
|
|
if let Some(text) = &obj.text {
|
|
let end = SNIPPET_LENGTH.min(text.len());
|
|
let mut snippet = text[..end].to_string();
|
|
if end < text.len() {
|
|
snippet = format!("{}...", snippet);
|
|
}
|
|
return snippet;
|
|
}
|
|
|
|
String::from("[No content]")
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// End-to-end smoke test of `SearchEngine::search` against a live server.
    ///
    /// Ignored by default because it connects to `redis://localhost:6379`;
    /// run it with `cargo test -- --ignored` when a server is available.
    #[tokio::test]
    #[ignore = "requires a running HeroDB/Redis-compatible server at redis://localhost:6379"]
    async fn test_search() {
        let client = HeroDbClient::new("redis://localhost:6379", 1)
            .expect("failed to create HeroDB client");
        let engine = SearchEngine::new(client);

        let query = RetrievalQuery::new("test".to_string())
            .with_text("rust".to_string())
            .with_top_k(10);

        let results = engine.search(&query).await.expect("search failed");

        // `top_k` caps the number of results.
        assert!(results.len() <= 10);
        // Results are returned best-first.
        assert!(results
            .windows(2)
            .all(|pair| pair[0].score >= pair[1].score));
        // With a text query only positive scores are kept, and scores are capped at 1.0.
        assert!(results.iter().all(|r| r.score > 0.0 && r.score <= 1.0));
    }
}
|