WIP6: implementing image embedding as first step towards multi-model support

This commit is contained in:
Maxime Van Hees
2025-09-30 14:53:01 +02:00
parent 7d07b57d32
commit 2139deb85d
7 changed files with 838 additions and 46 deletions

View File

@@ -15,8 +15,11 @@ use crate::storage_trait::StorageBackend;
use crate::admin_meta;
// Embeddings: config and cache
use crate::embedding::{EmbeddingConfig, create_embedder, Embedder};
use crate::embedding::{EmbeddingConfig, create_embedder, Embedder, create_image_embedder, ImageEmbedder};
use serde_json;
use ureq::{Agent, AgentBuilder};
use std::time::Duration;
use std::io::Read;
#[derive(Clone)]
pub struct Server {
@@ -33,9 +36,12 @@ pub struct Server {
// Per-DB Lance stores (vector DB), keyed by db_id
pub lance_stores: Arc<std::sync::RwLock<HashMap<u64, Arc<crate::lance_store::LanceStore>>>>,
// Per-(db_id, dataset) embedder cache
// Per-(db_id, dataset) embedder cache (text)
pub embedders: Arc<std::sync::RwLock<HashMap<(u64, String), Arc<dyn Embedder>>>>,
// Per-(db_id, dataset) image embedder cache (image)
pub image_embedders: Arc<std::sync::RwLock<HashMap<(u64, String), Arc<dyn ImageEmbedder>>>>,
// BLPOP waiter registry: per (db_index, key) FIFO of waiters
pub list_waiters: Arc<Mutex<HashMap<u64, HashMap<String, Vec<Waiter>>>>>,
pub waiter_seq: Arc<AtomicU64>,
@@ -66,6 +72,7 @@ impl Server {
search_indexes: Arc::new(std::sync::RwLock::new(HashMap::new())),
lance_stores: Arc::new(std::sync::RwLock::new(HashMap::new())),
embedders: Arc::new(std::sync::RwLock::new(HashMap::new())),
image_embedders: Arc::new(std::sync::RwLock::new(HashMap::new())),
list_waiters: Arc::new(Mutex::new(HashMap::new())),
waiter_seq: Arc::new(AtomicU64::new(1)),
}
@@ -189,6 +196,10 @@ impl Server {
let mut map = self.embedders.write().unwrap();
map.remove(&(self.selected_db, dataset.to_string()));
}
{
let mut map_img = self.image_embedders.write().unwrap();
map_img.remove(&(self.selected_db, dataset.to_string()));
}
Ok(())
}
@@ -233,6 +244,88 @@ impl Server {
Ok(emb)
}
/// Resolve or build the IMAGE embedder for `(selected_db, dataset)` and cache the instance.
///
/// The embedder is constructed outside of any lock. If another caller stored an embedder for
/// the same key in the meantime, that already-cached instance is returned and the freshly
/// built one is dropped, so all callers end up sharing a single instance per key.
///
/// # Errors
/// Returns a `DBError` when the admin DB (id 0) is selected, and propagates any error from
/// loading the dataset embedding config or from constructing the image embedder.
pub fn get_image_embedder_for(&self, dataset: &str) -> Result<Arc<dyn ImageEmbedder>, DBError> {
    if self.selected_db == 0 {
        return Err(DBError("Lance not available on admin DB 0".to_string()));
    }

    // Build the cache key once; it is reused for the lookup and the insert.
    let key = (self.selected_db, dataset.to_string());

    // Fast path: the read guard is a temporary and is released at the end of this statement.
    let cached = self.image_embedders.read().unwrap().get(&key).cloned();
    if let Some(emb) = cached {
        return Ok(emb);
    }

    // Slow path: load config and instantiate without holding the cache lock.
    let cfg = self.get_dataset_embedding_config(dataset)?;
    let emb = create_image_embedder(&cfg)?;

    // First writer wins: keep whatever is already cached instead of overwriting it, so
    // concurrent cache misses cannot hand out two different instances for one key.
    let mut map = self.image_embedders.write().unwrap();
    Ok(map.entry(key).or_insert(emb).clone())
}
/// Download image bytes from a URI with safety checks (size, timeout, content-type, optional host allowlist).
/// Env overrides:
/// - HERODB_IMAGE_MAX_BYTES (u64, default 10485760)
/// - HERODB_IMAGE_FETCH_TIMEOUT_SECS (u64, default 30)
/// - HERODB_IMAGE_ALLOWED_HOSTS (comma-separated, optional)
pub fn fetch_image_bytes_from_uri(&self, uri: &str) -> Result<Vec<u8>, DBError> {
// Basic scheme validation
if !(uri.starts_with("http://") || uri.starts_with("https://")) {
return Err(DBError("Only http(s) URIs are supported for image fetch".into()));
}
// Parse host (naive) for allowlist check
let host = {
let after_scheme = match uri.find("://") {
Some(i) => &uri[i + 3..],
None => uri,
};
let end = after_scheme.find('/').unwrap_or(after_scheme.len());
let host_port = &after_scheme[..end];
host_port.split('@').last().unwrap_or(host_port).split(':').next().unwrap_or(host_port).to_string()
};
let max_bytes: u64 = std::env::var("HERODB_IMAGE_MAX_BYTES").ok().and_then(|s| s.parse::<u64>().ok()).unwrap_or(10 * 1024 * 1024);
let timeout_secs: u64 = std::env::var("HERODB_IMAGE_FETCH_TIMEOUT_SECS").ok().and_then(|s| s.parse::<u64>().ok()).unwrap_or(30);
let allowed_hosts_env = std::env::var("HERODB_IMAGE_ALLOWED_HOSTS").ok();
if let Some(allow) = allowed_hosts_env {
if !allow.split(',').map(|s| s.trim()).filter(|s| !s.is_empty()).any(|h| h.eq_ignore_ascii_case(&host)) {
return Err(DBError(format!("Host '{}' not allowed for image fetch (HERODB_IMAGE_ALLOWED_HOSTS)", host)));
}
}
let agent: Agent = AgentBuilder::new()
.timeout_read(Duration::from_secs(timeout_secs))
.timeout_write(Duration::from_secs(timeout_secs))
.build();
let resp = agent.get(uri).call().map_err(|e| DBError(format!("HTTP GET failed: {}", e)))?;
// Validate content-type
let ctype = resp.header("Content-Type").unwrap_or("");
let ctype_main = ctype.split(';').next().unwrap_or("").trim().to_ascii_lowercase();
if !ctype_main.starts_with("image/") {
return Err(DBError(format!("Remote content-type '{}' is not image/*", ctype)));
}
// Read with cap
let mut reader = resp.into_reader();
let mut buf: Vec<u8> = Vec::with_capacity(8192);
let mut tmp = [0u8; 8192];
let mut total: u64 = 0;
loop {
let n = reader.read(&mut tmp).map_err(|e| DBError(format!("Read error: {}", e)))?;
if n == 0 { break; }
total += n as u64;
if total > max_bytes {
return Err(DBError(format!("Image exceeds max allowed bytes {}", max_bytes)));
}
buf.extend_from_slice(&tmp[..n]);
}
Ok(buf)
}
/// Check if current permissions allow read operations
pub fn has_read_permission(&self) -> bool {
// If an explicit permission is set for this connection, honor it.