WIP6: implementing image embedding as first step towards multi-model support
This commit is contained in:
@@ -15,8 +15,11 @@ use crate::storage_trait::StorageBackend;
|
||||
use crate::admin_meta;
|
||||
|
||||
// Embeddings: config and cache
|
||||
use crate::embedding::{EmbeddingConfig, create_embedder, Embedder};
|
||||
use crate::embedding::{EmbeddingConfig, create_embedder, Embedder, create_image_embedder, ImageEmbedder};
|
||||
use serde_json;
|
||||
use ureq::{Agent, AgentBuilder};
|
||||
use std::time::Duration;
|
||||
use std::io::Read;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Server {
|
||||
@@ -33,9 +36,12 @@ pub struct Server {
|
||||
// Per-DB Lance stores (vector DB), keyed by db_id
|
||||
pub lance_stores: Arc<std::sync::RwLock<HashMap<u64, Arc<crate::lance_store::LanceStore>>>>,
|
||||
|
||||
// Per-(db_id, dataset) embedder cache
|
||||
// Per-(db_id, dataset) embedder cache (text)
|
||||
pub embedders: Arc<std::sync::RwLock<HashMap<(u64, String), Arc<dyn Embedder>>>>,
|
||||
|
||||
// Per-(db_id, dataset) image embedder cache (image)
|
||||
pub image_embedders: Arc<std::sync::RwLock<HashMap<(u64, String), Arc<dyn ImageEmbedder>>>>,
|
||||
|
||||
// BLPOP waiter registry: per (db_index, key) FIFO of waiters
|
||||
pub list_waiters: Arc<Mutex<HashMap<u64, HashMap<String, Vec<Waiter>>>>>,
|
||||
pub waiter_seq: Arc<AtomicU64>,
|
||||
@@ -66,6 +72,7 @@ impl Server {
|
||||
search_indexes: Arc::new(std::sync::RwLock::new(HashMap::new())),
|
||||
lance_stores: Arc::new(std::sync::RwLock::new(HashMap::new())),
|
||||
embedders: Arc::new(std::sync::RwLock::new(HashMap::new())),
|
||||
image_embedders: Arc::new(std::sync::RwLock::new(HashMap::new())),
|
||||
list_waiters: Arc::new(Mutex::new(HashMap::new())),
|
||||
waiter_seq: Arc::new(AtomicU64::new(1)),
|
||||
}
|
||||
@@ -189,6 +196,10 @@ impl Server {
|
||||
let mut map = self.embedders.write().unwrap();
|
||||
map.remove(&(self.selected_db, dataset.to_string()));
|
||||
}
|
||||
{
|
||||
let mut map_img = self.image_embedders.write().unwrap();
|
||||
map_img.remove(&(self.selected_db, dataset.to_string()));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -233,6 +244,88 @@ impl Server {
|
||||
Ok(emb)
|
||||
}
|
||||
|
||||
/// Resolve or build an IMAGE embedder for (db_id, dataset). Caches instance.
|
||||
pub fn get_image_embedder_for(&self, dataset: &str) -> Result<Arc<dyn ImageEmbedder>, DBError> {
|
||||
if self.selected_db == 0 {
|
||||
return Err(DBError("Lance not available on admin DB 0".to_string()));
|
||||
}
|
||||
// Fast path
|
||||
{
|
||||
let map = self.image_embedders.read().unwrap();
|
||||
if let Some(e) = map.get(&(self.selected_db, dataset.to_string())) {
|
||||
return Ok(e.clone());
|
||||
}
|
||||
}
|
||||
// Load config and instantiate
|
||||
let cfg = self.get_dataset_embedding_config(dataset)?;
|
||||
let emb = create_image_embedder(&cfg)?;
|
||||
{
|
||||
let mut map = self.image_embedders.write().unwrap();
|
||||
map.insert((self.selected_db, dataset.to_string()), emb.clone());
|
||||
}
|
||||
Ok(emb)
|
||||
}
|
||||
|
||||
/// Download image bytes from a URI with safety checks (size, timeout, content-type, optional host allowlist).
|
||||
/// Env overrides:
|
||||
/// - HERODB_IMAGE_MAX_BYTES (u64, default 10485760)
|
||||
/// - HERODB_IMAGE_FETCH_TIMEOUT_SECS (u64, default 30)
|
||||
/// - HERODB_IMAGE_ALLOWED_HOSTS (comma-separated, optional)
|
||||
pub fn fetch_image_bytes_from_uri(&self, uri: &str) -> Result<Vec<u8>, DBError> {
|
||||
// Basic scheme validation
|
||||
if !(uri.starts_with("http://") || uri.starts_with("https://")) {
|
||||
return Err(DBError("Only http(s) URIs are supported for image fetch".into()));
|
||||
}
|
||||
// Parse host (naive) for allowlist check
|
||||
let host = {
|
||||
let after_scheme = match uri.find("://") {
|
||||
Some(i) => &uri[i + 3..],
|
||||
None => uri,
|
||||
};
|
||||
let end = after_scheme.find('/').unwrap_or(after_scheme.len());
|
||||
let host_port = &after_scheme[..end];
|
||||
host_port.split('@').last().unwrap_or(host_port).split(':').next().unwrap_or(host_port).to_string()
|
||||
};
|
||||
|
||||
let max_bytes: u64 = std::env::var("HERODB_IMAGE_MAX_BYTES").ok().and_then(|s| s.parse::<u64>().ok()).unwrap_or(10 * 1024 * 1024);
|
||||
let timeout_secs: u64 = std::env::var("HERODB_IMAGE_FETCH_TIMEOUT_SECS").ok().and_then(|s| s.parse::<u64>().ok()).unwrap_or(30);
|
||||
let allowed_hosts_env = std::env::var("HERODB_IMAGE_ALLOWED_HOSTS").ok();
|
||||
if let Some(allow) = allowed_hosts_env {
|
||||
if !allow.split(',').map(|s| s.trim()).filter(|s| !s.is_empty()).any(|h| h.eq_ignore_ascii_case(&host)) {
|
||||
return Err(DBError(format!("Host '{}' not allowed for image fetch (HERODB_IMAGE_ALLOWED_HOSTS)", host)));
|
||||
}
|
||||
}
|
||||
|
||||
let agent: Agent = AgentBuilder::new()
|
||||
.timeout_read(Duration::from_secs(timeout_secs))
|
||||
.timeout_write(Duration::from_secs(timeout_secs))
|
||||
.build();
|
||||
|
||||
let resp = agent.get(uri).call().map_err(|e| DBError(format!("HTTP GET failed: {}", e)))?;
|
||||
// Validate content-type
|
||||
let ctype = resp.header("Content-Type").unwrap_or("");
|
||||
let ctype_main = ctype.split(';').next().unwrap_or("").trim().to_ascii_lowercase();
|
||||
if !ctype_main.starts_with("image/") {
|
||||
return Err(DBError(format!("Remote content-type '{}' is not image/*", ctype)));
|
||||
}
|
||||
|
||||
// Read with cap
|
||||
let mut reader = resp.into_reader();
|
||||
let mut buf: Vec<u8> = Vec::with_capacity(8192);
|
||||
let mut tmp = [0u8; 8192];
|
||||
let mut total: u64 = 0;
|
||||
loop {
|
||||
let n = reader.read(&mut tmp).map_err(|e| DBError(format!("Read error: {}", e)))?;
|
||||
if n == 0 { break; }
|
||||
total += n as u64;
|
||||
if total > max_bytes {
|
||||
return Err(DBError(format!("Image exceeds max allowed bytes {}", max_bytes)));
|
||||
}
|
||||
buf.extend_from_slice(&tmp[..n]);
|
||||
}
|
||||
Ok(buf)
|
||||
}
|
||||
|
||||
/// Check if current permissions allow read operations
|
||||
pub fn has_read_permission(&self) -> bool {
|
||||
// If an explicit permission is set for this connection, honor it.
|
||||
|
||||
Reference in New Issue
Block a user