From 2139deb85db4aedd5eb4b1c11bf8180047bd21f8 Mon Sep 17 00:00:00 2001 From: Maxime Van Hees Date: Tue, 30 Sep 2025 14:53:01 +0200 Subject: [PATCH] WIP6: implementing image embedding as first step towards multi-model support --- docs/lancedb_text_and_images_example.md | 138 +++++++++++++ src/cmd.rs | 259 ++++++++++++++++++++++++ src/crypto.rs | 6 +- src/embedding.rs | 14 -- src/lance_store.rs | 142 ++++++++++--- src/rpc.rs | 228 +++++++++++++++++++++ src/server.rs | 97 ++++++++- 7 files changed, 838 insertions(+), 46 deletions(-) create mode 100644 docs/lancedb_text_and_images_example.md diff --git a/docs/lancedb_text_and_images_example.md b/docs/lancedb_text_and_images_example.md new file mode 100644 index 0000000..d4db68c --- /dev/null +++ b/docs/lancedb_text_and_images_example.md @@ -0,0 +1,138 @@ +# LanceDB Text and Images: End-to-End Example + +This guide demonstrates creating a Lance backend database, ingesting two text documents and two images, performing searches over both, and cleaning up the datasets. + +Prerequisites +- Build HeroDB and start the server with JSON-RPC enabled. +Commands: +```bash +cargo build --release +./target/release/herodb --dir /tmp/herodb --admin-secret mysecret --port 6379 --enable-rpc +``` + +We'll use: +- redis-cli for RESP commands against port 6379 +- curl for JSON-RPC against 8080 if desired +- Deterministic local embedders to avoid external dependencies: testhash (text, dim 64) and testimagehash (image, dim 512) + +0) Create a Lance-backed database (JSON-RPC) +Request: +```json +{ "jsonrpc": "2.0", "id": 1, "method": "herodb_createDatabase", "params": ["Lance", { "name": "media-db", "storage_path": null, "max_size": null, "redis_version": null }, null] } +``` +Response returns db_id (assume 1). Select DB over RESP: +```bash +redis-cli -p 6379 SELECT 1 +# → OK +``` + +1) Configure embedding providers +We'll create two datasets with independent embedding configs: +- textset → provider testhash, dim 64 +- imageset → provider testimagehash, dim 512 + +Text config: +```bash +redis-cli -p 6379 LANCE.EMBEDDING CONFIG SET textset PROVIDER testhash MODEL any PARAM dim 64 +# → OK +``` +Image config: +```bash +redis-cli -p 6379 LANCE.EMBEDDING CONFIG SET imageset PROVIDER testimagehash MODEL any PARAM dim 512 +# → OK +``` + +2) Create datasets +```bash +redis-cli -p 6379 LANCE.CREATE textset DIM 64 +# → OK +redis-cli -p 6379 LANCE.CREATE imageset DIM 512 +# → OK +``` + +3) Ingest two text documents (server-side embedding) +```bash +redis-cli -p 6379 LANCE.STORE textset ID doc-1 TEXT "The quick brown fox jumps over the lazy dog" META title "Fox" category "animal" +# → OK +redis-cli -p 6379 LANCE.STORE textset ID doc-2 TEXT "A fast auburn fox vaulted a sleepy canine" META title "Paraphrase" category "animal" +# → OK +``` + +4) Ingest two images +You can provide a URI or base64 bytes. Use URI for URIs, BYTES for base64 data. 
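The exact command shapes, as implemented by the parser in src/cmd.rs further down in this patch (angle-bracket placeholders added here for readability), are:
```
LANCE.STOREIMAGE <dataset> ID <id> (URI <uri> | BYTES <base64>) [META k v ...]
LANCE.SEARCHIMAGE <dataset> K <k> (QUERYURI <uri> | QUERYBYTES <base64>) [FILTER <expr>] [RETURN <n> <field> ...]
```
Exactly one of the URI/BYTES (or QUERYURI/QUERYBYTES) forms may be given per call.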
+Example using free placeholder images: +```bash +# Store via URI +redis-cli -p 6379 LANCE.STOREIMAGE imageset ID img-1 URI "https://picsum.photos/seed/1/256/256" META title "Seed1" group "demo" +# → OK +redis-cli -p 6379 LANCE.STOREIMAGE imageset ID img-2 URI "https://picsum.photos/seed/2/256/256" META title "Seed2" group "demo" +# → OK +``` +If your environment blocks outbound HTTP, you can embed image bytes: +```bash +# Example: read a local file and base64 it (replace path) +b64=$(base64 -w0 ./image1.png) +redis-cli -p 6379 LANCE.STOREIMAGE imageset ID img-b64-1 BYTES "$b64" META title "Local1" group "demo" +``` + +5) Search text +```bash +# Top-2 nearest neighbors for a query +redis-cli -p 6379 LANCE.SEARCH textset K 2 QUERY "quick brown fox" RETURN 1 title +# → 1) [id, score, [k1,v1,...]] +``` +With a filter (supports equality on schema or meta keys): +```bash +redis-cli -p 6379 LANCE.SEARCH textset K 2 QUERY "fox jumps" FILTER "category = 'animal'" RETURN 1 title +``` + +6) Search images +```bash +# Provide a URI as the query +redis-cli -p 6379 LANCE.SEARCHIMAGE imageset K 2 QUERYURI "https://picsum.photos/seed/1/256/256" RETURN 1 title + +# Or provide base64 bytes as the query +qb64=$(curl -s https://picsum.photos/seed/3/256/256 | base64 -w0) +redis-cli -p 6379 LANCE.SEARCHIMAGE imageset K 2 QUERYBYTES "$qb64" RETURN 1 title +``` + +7) Inspect datasets +```bash +redis-cli -p 6379 LANCE.LIST +redis-cli -p 6379 LANCE.INFO textset +redis-cli -p 6379 LANCE.INFO imageset +``` + +8) Delete by id and drop datasets +```bash +# Delete one record +redis-cli -p 6379 LANCE.DEL textset doc-2 +# → OK + +# Drop entire datasets +redis-cli -p 6379 LANCE.DROP textset +redis-cli -p 6379 LANCE.DROP imageset +# → OK +``` + +Appendix: Using OpenAI embeddings instead of test providers +Text: +```bash +export OPENAI_API_KEY=sk-... +redis-cli -p 6379 LANCE.EMBEDDING CONFIG SET textset PROVIDER openai MODEL text-embedding-3-small PARAM dim 512 +redis-cli -p 6379 LANCE.CREATE textset DIM 512 +``` +Azure OpenAI: +```bash +export AZURE_OPENAI_API_KEY=... +redis-cli -p 6379 LANCE.EMBEDDING CONFIG SET textset PROVIDER openai MODEL text-embedding-3-small \ + PARAM use_azure true \ + PARAM azure_endpoint https://myresource.openai.azure.com \ + PARAM azure_deployment my-embed-deploy \ + PARAM azure_api_version 2024-02-15 \ + PARAM dim 512 +``` +Notes: +- Ensure dataset DIM matches the configured embedding dimension. +- Lance is only available for non-admin databases (db_id >= 1). +- On Lance DBs, only LANCE.* and basic control commands are allowed. 
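Appendix: Calling the image endpoints over JSON-RPC
The same image operations are exposed through the lanceStoreImage and lanceSearchImage RPC methods added in src/rpc.rs. A minimal curl sketch, assuming the herodb_ method prefix shown earlier and the JSON-RPC listener on port 8080 (adjust the URL to your deployment):
```bash
# Store an image by URI into dataset "imageset" of db_id 1
curl -s -X POST http://127.0.0.1:8080 -H 'Content-Type: application/json' -d '{
  "jsonrpc": "2.0", "id": 10, "method": "herodb_lanceStoreImage",
  "params": [1, "imageset", "img-rpc-1", "https://picsum.photos/seed/4/256/256", null, {"title": "Seed4"}]
}'

# Query with an image URI (top-2 results, returning only the title field)
curl -s -X POST http://127.0.0.1:8080 -H 'Content-Type: application/json' -d '{
  "jsonrpc": "2.0", "id": 11, "method": "herodb_lanceSearchImage",
  "params": [1, "imageset", 2, "https://picsum.photos/seed/4/256/256", null, null, ["title"]]
}'
```
Exactly one of uri or bytes_b64 may be non-null per call; the image size and fetch limits (HERODB_IMAGE_MAX_BYTES, HERODB_IMAGE_FETCH_TIMEOUT_SECS, HERODB_IMAGE_ALLOWED_HOSTS) apply to both the RESP and JSON-RPC paths.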
\ No newline at end of file diff --git a/src/cmd.rs b/src/cmd.rs index 66c63a6..88884da 100644 --- a/src/cmd.rs +++ b/src/cmd.rs @@ -1,4 +1,5 @@ use crate::{error::DBError, protocol::Protocol, server::Server, embedding::{EmbeddingConfig, EmbeddingProvider}}; +use base64::{engine::general_purpose, Engine as _}; use tokio::time::{timeout, Duration}; use futures::future::select_all; @@ -145,6 +146,22 @@ pub enum Cmd { filter: Option, return_fields: Option>, }, + // Image-first commands (no user-provided vectors) + LanceStoreImage { + name: String, + id: String, + uri: Option, + bytes_b64: Option, + meta: Vec<(String, String)>, + }, + LanceSearchImage { + name: String, + k: usize, + uri: Option, + bytes_b64: Option, + filter: Option, + return_fields: Option>, + }, LanceCreateIndex { name: String, index_type: String, @@ -903,6 +920,46 @@ impl Cmd { } Cmd::LanceStoreText { name, id, text, meta } } + "lance.storeimage" => { + // LANCE.STOREIMAGE name ID (URI | BYTES ) [META k v ...] + if cmd.len() < 6 { + return Err(DBError("ERR LANCE.STOREIMAGE requires: name ID (URI | BYTES ) [META k v ...]".to_string())); + } + let name = cmd[1].clone(); + let mut i = 2; + if cmd[i].to_uppercase() != "ID" || i + 1 >= cmd.len() { + return Err(DBError("ERR LANCE.STOREIMAGE requires ID ".to_string())); + } + let id = cmd[i + 1].clone(); + i += 2; + + let mut uri_opt: Option = None; + let mut bytes_b64_opt: Option = None; + + if i < cmd.len() && cmd[i].to_uppercase() == "URI" { + if i + 1 >= cmd.len() { return Err(DBError("ERR LANCE.STOREIMAGE URI requires a value".to_string())); } + uri_opt = Some(cmd[i + 1].clone()); + i += 2; + } else if i < cmd.len() && cmd[i].to_uppercase() == "BYTES" { + if i + 1 >= cmd.len() { return Err(DBError("ERR LANCE.STOREIMAGE BYTES requires a value".to_string())); } + bytes_b64_opt = Some(cmd[i + 1].clone()); + i += 2; + } else { + return Err(DBError("ERR LANCE.STOREIMAGE requires either URI or BYTES ".to_string())); + } + + // Parse optional META pairs + let mut meta: Vec<(String, String)> = Vec::new(); + if i < cmd.len() && cmd[i].to_uppercase() == "META" { + i += 1; + while i + 1 < cmd.len() { + meta.push((cmd[i].clone(), cmd[i + 1].clone())); + i += 2; + } + } + + Cmd::LanceStoreImage { name, id, uri: uri_opt, bytes_b64: bytes_b64_opt, meta } + } "lance.search" => { // LANCE.SEARCH name K QUERY [FILTER expr] [RETURN n fields...] if cmd.len() < 6 { @@ -954,6 +1011,65 @@ impl Cmd { } Cmd::LanceSearchText { name, text, k, filter, return_fields } } + "lance.searchimage" => { + // LANCE.SEARCHIMAGE name K (QUERYURI | QUERYBYTES ) [FILTER expr] [RETURN n fields...] 
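+ // Mirrors the LANCE.STOREIMAGE parser above: exactly one of QUERYURI or QUERYBYTES is accepted, and the execute path re-validates that exactly one was provided.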
+ if cmd.len() < 6 { + return Err(DBError("ERR LANCE.SEARCHIMAGE requires: name K (QUERYURI | QUERYBYTES ) [FILTER expr] [RETURN n fields...]".to_string())); + } + let name = cmd[1].clone(); + if cmd[2].to_uppercase() != "K" { + return Err(DBError("ERR LANCE.SEARCHIMAGE requires K ".to_string())); + } + let k: usize = cmd[3].parse().map_err(|_| DBError("ERR K must be an integer".to_string()))?; + let mut i = 4; + + let mut uri_opt: Option = None; + let mut bytes_b64_opt: Option = None; + + if i < cmd.len() && cmd[i].to_uppercase() == "QUERYURI" { + if i + 1 >= cmd.len() { return Err(DBError("ERR QUERYURI requires a value".to_string())); } + uri_opt = Some(cmd[i + 1].clone()); + i += 2; + } else if i < cmd.len() && cmd[i].to_uppercase() == "QUERYBYTES" { + if i + 1 >= cmd.len() { return Err(DBError("ERR QUERYBYTES requires a value".to_string())); } + bytes_b64_opt = Some(cmd[i + 1].clone()); + i += 2; + } else { + return Err(DBError("ERR LANCE.SEARCHIMAGE requires QUERYURI or QUERYBYTES ".to_string())); + } + + let mut filter: Option = None; + let mut return_fields: Option> = None; + while i < cmd.len() { + match cmd[i].to_uppercase().as_str() { + "FILTER" => { + if i + 1 >= cmd.len() { + return Err(DBError("ERR FILTER requires an expression".to_string())); + } + filter = Some(cmd[i + 1].clone()); + i += 2; + } + "RETURN" => { + if i + 1 >= cmd.len() { + return Err(DBError("ERR RETURN requires field count".to_string())); + } + let n: usize = cmd[i + 1].parse().map_err(|_| DBError("ERR RETURN count must be integer".to_string()))?; + i += 2; + let mut fields = Vec::new(); + for _ in 0..n { + if i < cmd.len() { + fields.push(cmd[i].clone()); + i += 1; + } + } + return_fields = Some(fields); + } + _ => { i += 1; } + } + } + + Cmd::LanceSearchImage { name, k, uri: uri_opt, bytes_b64: bytes_b64_opt, filter, return_fields } + } "lance.createindex" => { // LANCE.CREATEINDEX name TYPE t [PARAM k v ...] if cmd.len() < 4 || cmd[2].to_uppercase() != "TYPE" { @@ -1136,6 +1252,8 @@ impl Cmd { | Cmd::LanceCreate { .. } | Cmd::LanceStoreText { .. } | Cmd::LanceSearchText { .. } + | Cmd::LanceStoreImage { .. } + | Cmd::LanceSearchImage { .. } | Cmd::LanceEmbeddingConfigSet { .. } | Cmd::LanceEmbeddingConfigGet { .. } | Cmd::LanceCreateIndex { .. } @@ -1172,6 +1290,8 @@ impl Cmd { Cmd::LanceCreate { .. } | Cmd::LanceStoreText { .. } | Cmd::LanceSearchText { .. } + | Cmd::LanceStoreImage { .. } + | Cmd::LanceSearchImage { .. } | Cmd::LanceEmbeddingConfigSet { .. } | Cmd::LanceEmbeddingConfigGet { .. } | Cmd::LanceCreateIndex { .. 
} @@ -1421,6 +1541,145 @@ impl Cmd { Err(e) => Ok(Protocol::err(&e.0)), } } + + // New: Image store + Cmd::LanceStoreImage { name, id, uri, bytes_b64, meta } => { + if !server.has_write_permission() { + return Ok(Protocol::err("ERR write permission denied")); + } + let use_uri = uri.is_some(); + let use_b64 = bytes_b64.is_some(); + if (use_uri && use_b64) || (!use_uri && !use_b64) { + return Ok(Protocol::err("ERR Provide exactly one of URI or BYTES for LANCE.STOREIMAGE")); + } + let max_bytes: usize = std::env::var("HERODB_IMAGE_MAX_BYTES") + .ok() + .and_then(|s| s.parse::().ok()) + .unwrap_or(10 * 1024 * 1024) as usize; + + let media_uri_opt = if let Some(u) = uri.clone() { + match server.fetch_image_bytes_from_uri(&u) { + Ok(_) => {} + Err(e) => return Ok(Protocol::err(&e.0)), + } + Some(u) + } else { + None + }; + + let bytes: Vec = if let Some(u) = uri { + match server.fetch_image_bytes_from_uri(&u) { + Ok(b) => b, + Err(e) => return Ok(Protocol::err(&e.0)), + } + } else { + let b64 = bytes_b64.unwrap_or_default(); + let data = match general_purpose::STANDARD.decode(b64.as_bytes()) { + Ok(d) => d, + Err(e) => return Ok(Protocol::err(&format!("ERR base64 decode error: {}", e))), + }; + if data.len() > max_bytes { + return Ok(Protocol::err(&format!("ERR image exceeds max allowed bytes {}", max_bytes))); + } + data + }; + + let img_embedder = match server.get_image_embedder_for(&name) { + Ok(e) => e, + Err(e) => return Ok(Protocol::err(&e.0)), + }; + let (tx, rx) = tokio::sync::oneshot::channel(); + let emb_arc = img_embedder.clone(); + let bytes_cl = bytes.clone(); + std::thread::spawn(move || { + let res = emb_arc.embed_image(&bytes_cl); + let _ = tx.send(res); + }); + let vector = match rx.await { + Ok(Ok(v)) => v, + Ok(Err(e)) => return Ok(Protocol::err(&e.0)), + Err(recv_err) => return Ok(Protocol::err(&format!("ERR embedding thread error: {}", recv_err))), + }; + + let meta_map: std::collections::HashMap = meta.into_iter().collect(); + match server.lance_store()?.store_vector_with_media( + &name, + &id, + vector, + meta_map, + None, + Some("image".to_string()), + media_uri_opt, + ).await { + Ok(()) => Ok(Protocol::SimpleString("OK".to_string())), + Err(e) => Ok(Protocol::err(&e.0)), + } + } + + // New: Image search + Cmd::LanceSearchImage { name, k, uri, bytes_b64, filter, return_fields } => { + let use_uri = uri.is_some(); + let use_b64 = bytes_b64.is_some(); + if (use_uri && use_b64) || (!use_uri && !use_b64) { + return Ok(Protocol::err("ERR Provide exactly one of QUERYURI or QUERYBYTES for LANCE.SEARCHIMAGE")); + } + let max_bytes: usize = std::env::var("HERODB_IMAGE_MAX_BYTES") + .ok() + .and_then(|s| s.parse::().ok()) + .unwrap_or(10 * 1024 * 1024) as usize; + + let bytes: Vec = if let Some(u) = uri { + match server.fetch_image_bytes_from_uri(&u) { + Ok(b) => b, + Err(e) => return Ok(Protocol::err(&e.0)), + } + } else { + let b64 = bytes_b64.unwrap_or_default(); + let data = match general_purpose::STANDARD.decode(b64.as_bytes()) { + Ok(d) => d, + Err(e) => return Ok(Protocol::err(&format!("ERR base64 decode error: {}", e))), + }; + if data.len() > max_bytes { + return Ok(Protocol::err(&format!("ERR image exceeds max allowed bytes {}", max_bytes))); + } + data + }; + + let img_embedder = match server.get_image_embedder_for(&name) { + Ok(e) => e, + Err(e) => return Ok(Protocol::err(&e.0)), + }; + let (tx, rx) = tokio::sync::oneshot::channel(); + std::thread::spawn(move || { + let res = img_embedder.embed_image(&bytes); + let _ = tx.send(res); + }); + let qv = match rx.await { + 
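+ // Outer result: did the embedding thread deliver anything (Err means the sender was dropped, e.g. the thread panicked); inner result: did embed_image itself succeed.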
Ok(Ok(v)) => v, + Ok(Err(e)) => return Ok(Protocol::err(&e.0)), + Err(recv_err) => return Ok(Protocol::err(&format!("ERR embedding thread error: {}", recv_err))), + }; + + match server.lance_store()?.search_vectors(&name, qv, k, filter, return_fields).await { + Ok(results) => { + let mut arr = Vec::new(); + for (id, score, meta) in results { + let mut meta_arr: Vec = Vec::new(); + for (k, v) in meta { + meta_arr.push(Protocol::BulkString(k)); + meta_arr.push(Protocol::BulkString(v)); + } + arr.push(Protocol::Array(vec![ + Protocol::BulkString(id), + Protocol::BulkString(score.to_string()), + Protocol::Array(meta_arr), + ])); + } + Ok(Protocol::Array(arr)) + } + Err(e) => Ok(Protocol::err(&e.0)), + } + } Cmd::LanceCreateIndex { name, index_type, params } => { if !server.has_write_permission() { return Ok(Protocol::err("ERR write permission denied")); diff --git a/src/crypto.rs b/src/crypto.rs index 48a9f8c..b63fb53 100644 --- a/src/crypto.rs +++ b/src/crypto.rs @@ -1,8 +1,8 @@ use chacha20poly1305::{ - aead::{Aead, KeyInit, OsRng}, + aead::{Aead, KeyInit}, XChaCha20Poly1305, XNonce, }; -use rand::RngCore; +use rand::{rngs::OsRng, RngCore}; use sha2::{Digest, Sha256}; const VERSION: u8 = 1; @@ -31,7 +31,7 @@ pub struct CryptoFactory { impl CryptoFactory { /// Accepts any secret bytes; turns them into a 32-byte key (SHA-256). pub fn new>(secret: S) -> Self { - let mut h = Sha256::new(); + let mut h = Sha256::default(); h.update(b"xchacha20poly1305-factory:v1"); // domain separation h.update(secret.as_ref()); let digest = h.finalize(); // 32 bytes diff --git a/src/embedding.rs b/src/embedding.rs index 36e2e99..79690d4 100644 --- a/src/embedding.rs +++ b/src/embedding.rs @@ -1,18 +1,4 @@ // Embedding abstraction and minimal providers. -// -// This module defines a provider-agnostic interface to produce vector embeddings -// from text, so callers never need to supply vectors manually. It includes: -// - Embedder trait -// - EmbeddingProvider and EmbeddingConfig (serde-serializable) -// - TestHashEmbedder: deterministic, CPU-only, no-network embedder suitable for CI -// - Factory create_embedder(..) to instantiate an embedder from config -// -// Integration plan: -// - Server will resolve per-dataset EmbeddingConfig from sidecar JSON and cache Arc -// - LanceStore will call embedder.embed(text) then persist id, vector, text, meta -// -// Note: Real LanceDB-backed embedding providers can be added by implementing Embedder -// and extending create_embedder(..). This file keeps no direct dependency on LanceDB. 
use std::collections::HashMap; use std::sync::Arc; diff --git a/src/lance_store.rs b/src/lance_store.rs index 78f1923..0e1b085 100644 --- a/src/lance_store.rs +++ b/src/lance_store.rs @@ -112,6 +112,8 @@ impl LanceStore { Field::new("id", DataType::Utf8, false), Self::vector_field(dim), Field::new("text", DataType::Utf8, true), + Field::new("media_type", DataType::Utf8, true), + Field::new("media_uri", DataType::Utf8, true), Field::new("meta", DataType::Utf8, true), ])) } @@ -121,6 +123,8 @@ impl LanceStore { vector: &[f32], meta: &HashMap, text: Option<&str>, + media_type: Option<&str>, + media_uri: Option<&str>, dim: i32, ) -> Result<(Arc, RecordBatch), DBError> { if vector.len() as i32 != dim { @@ -156,6 +160,24 @@ impl LanceStore { } let text_arr = Arc::new(text_builder.finish()) as Arc; + // media_type column (optional) + let mut mt_builder = StringBuilder::new(); + if let Some(mt) = media_type { + mt_builder.append_value(mt); + } else { + mt_builder.append_null(); + } + let mt_arr = Arc::new(mt_builder.finish()) as Arc; + + // media_uri column (optional) + let mut mu_builder = StringBuilder::new(); + if let Some(mu) = media_uri { + mu_builder.append_value(mu); + } else { + mu_builder.append_null(); + } + let mu_arr = Arc::new(mu_builder.finish()) as Arc; + // meta column (JSON string) let meta_json = if meta.is_empty() { None @@ -171,7 +193,7 @@ impl LanceStore { let meta_arr = Arc::new(meta_builder.finish()) as Arc; let batch = - RecordBatch::try_new(schema.clone(), vec![id_arr, vec_arr, text_arr, meta_arr]).map_err(|e| { + RecordBatch::try_new(schema.clone(), vec![id_arr, vec_arr, text_arr, mt_arr, mu_arr, meta_arr]).map_err(|e| { DBError(format!("RecordBatch build failed: {e}")) })?; @@ -207,10 +229,12 @@ impl LanceStore { let mut list_builder = FixedSizeListBuilder::new(v_builder, dim_i32); let empty_vec = Arc::new(list_builder.finish()) as Arc; let empty_text = Arc::new(StringArray::new_null(0)); + let empty_media_type = Arc::new(StringArray::new_null(0)); + let empty_media_uri = Arc::new(StringArray::new_null(0)); let empty_meta = Arc::new(StringArray::new_null(0)); let empty_batch = - RecordBatch::try_new(schema.clone(), vec![empty_id, empty_vec, empty_text, empty_meta]) + RecordBatch::try_new(schema.clone(), vec![empty_id, empty_vec, empty_text, empty_media_type, empty_media_uri, empty_meta]) .map_err(|e| DBError(format!("Build empty batch failed: {e}")))?; let write_params = WriteParams { @@ -235,6 +259,21 @@ impl LanceStore { vector: Vec, meta: HashMap, text: Option, + ) -> Result<(), DBError> { + // Delegate to media-aware path with no media fields + self.store_vector_with_media(name, id, vector, meta, text, None, None).await + } + + /// Store/Upsert a single vector with optional text and media fields (media_type/media_uri). + pub async fn store_vector_with_media( + &self, + name: &str, + id: &str, + vector: Vec, + meta: HashMap, + text: Option, + media_type: Option, + media_uri: Option, ) -> Result<(), DBError> { let path = self.dataset_path(name); @@ -248,7 +287,15 @@ impl LanceStore { .map_err(|_| DBError("Vector length too large".into()))? 
}; - let (schema, batch) = Self::build_one_row_batch(id, &vector, &meta, text.as_deref(), dim_i32)?; + let (schema, batch) = Self::build_one_row_batch( + id, + &vector, + &meta, + text.as_deref(), + media_type.as_deref(), + media_uri.as_deref(), + dim_i32, + )?; // If LanceDB table exists and provides delete, we can upsert by deleting same id // Try best-effort delete; ignore errors to keep operation append-only on failure @@ -355,21 +402,36 @@ impl LanceStore { .await .map_err(|e| DBError(format!("Open dataset failed: {}", e)))?; - // Build scanner with projection; filter if provided + // Build scanner with projection; we project needed fields and filter client-side to support meta keys let mut scan = ds.scan(); - if let Err(e) = scan.project(&["id", "vector", "meta"]) { + if let Err(e) = scan.project(&["id", "vector", "meta", "text", "media_type", "media_uri"]) { return Err(DBError(format!("Project failed: {}", e))); } - if let Some(pred) = filter { - if let Err(e) = scan.filter(&pred) { - return Err(DBError(format!("Filter failed: {}", e))); - } - } + // Note: we no longer push down filter to Lance to allow filtering on meta fields client-side. let mut stream = scan .try_into_stream() .await .map_err(|e| DBError(format!("Scan stream failed: {}", e)))?; + + // Parse simple equality clause from filter for client-side filtering (supports one `key = 'value'`) + let clause = filter.as_ref().and_then(|s| { + fn parse_eq(s: &str) -> Option<(String, String)> { + let s = s.trim(); + let pos = s.find('=').or_else(|| s.find(" = "))?; + let (k, vraw) = s.split_at(pos); + let mut v = vraw.trim_start_matches('=').trim(); + if (v.starts_with('\'') && v.ends_with('\'')) || (v.starts_with('"') && v.ends_with('"')) { + if v.len() >= 2 { + v = &v[1..v.len()-1]; + } + } + let key = k.trim().trim_matches('"').trim_matches('\'').to_string(); + if key.is_empty() { return None; } + Some((key, v.to_string())) + } + parse_eq(s) + }); // Maintain a max-heap with reverse ordering to keep top-k smallest distances #[derive(Debug)] @@ -412,20 +474,18 @@ impl LanceStore { let meta_arr = batch .column_by_name("meta") .map(|a| a.as_string::().clone()); + let text_arr = batch + .column_by_name("text") + .map(|a| a.as_string::().clone()); + let mt_arr = batch + .column_by_name("media_type") + .map(|a| a.as_string::().clone()); + let mu_arr = batch + .column_by_name("media_uri") + .map(|a| a.as_string::().clone()); for i in 0..batch.num_rows() { - // Compute L2 distance - let val = vec_arr.value(i); - let prim = val.as_primitive::(); - let mut dist: f32 = 0.0; - let plen = prim.len(); - for j in 0..plen { - let r = prim.value(j); - let d = query[j] - r; - dist += d * d; - } - - // Parse id + // Extract id let id_val = id_arr.value(i).to_string(); // Parse meta JSON if present @@ -439,26 +499,54 @@ impl LanceStore { meta.insert(k, vs.to_string()); } else if v.is_number() || v.is_boolean() { meta.insert(k, v.to_string()); - } else { - // skip complex entries } } } } } + // Evaluate simple equality filter if provided (supports one clause) + let passes = if let Some((ref key, ref val)) = clause { + let candidate = match key.as_str() { + "id" => Some(id_val.clone()), + "text" => text_arr.as_ref().and_then(|col| if col.is_null(i) { None } else { Some(col.value(i).to_string()) }), + "media_type" => mt_arr.as_ref().and_then(|col| if col.is_null(i) { None } else { Some(col.value(i).to_string()) }), + "media_uri" => mu_arr.as_ref().and_then(|col| if col.is_null(i) { None } else { Some(col.value(i).to_string()) }), + _ => 
meta.get(key).cloned(), + }; + match candidate { + Some(cv) => cv == *val, + None => false, + } + } else { true }; + if !passes { + continue; + } + + // Compute L2 distance + let val = vec_arr.value(i); + let prim = val.as_primitive::(); + let mut dist: f32 = 0.0; + let plen = prim.len(); + for j in 0..plen { + let r = prim.value(j); + let d = query[j] - r; + dist += d * d; + } + // Apply return_fields on meta + let mut meta_out = meta; if let Some(fields) = &return_fields { let mut filtered = HashMap::new(); for f in fields { - if let Some(val) = meta.get(f) { + if let Some(val) = meta_out.get(f) { filtered.insert(f.clone(), val.clone()); } } - meta = filtered; + meta_out = filtered; } - let hit = Hit { dist, id: id_val, meta }; + let hit = Hit { dist, id: id_val, meta: meta_out }; if heap.len() < k { heap.push(hit); diff --git a/src/rpc.rs b/src/rpc.rs index b589efd..609014f 100644 --- a/src/rpc.rs +++ b/src/rpc.rs @@ -10,6 +10,7 @@ use crate::server::Server; use crate::options::DBOption; use crate::admin_meta; use crate::embedding::{EmbeddingConfig, EmbeddingProvider}; +use base64::{engine::general_purpose, Engine as _}; /// Database backend types #[derive(Debug, Clone, Serialize, Deserialize)] @@ -282,6 +283,33 @@ pub trait Rpc { filter: Option, return_fields: Option>, ) -> RpcResult; + + // ----- Image-first endpoints (no user-provided vectors) ----- + + /// Store an image; exactly one of uri or bytes_b64 must be provided. + #[method(name = "lanceStoreImage")] + async fn lance_store_image( + &self, + db_id: u64, + name: String, + id: String, + uri: Option, + bytes_b64: Option, + meta: Option>, + ) -> RpcResult; + + /// Search using an image query; exactly one of uri or bytes_b64 must be provided. + #[method(name = "lanceSearchImage")] + async fn lance_search_image( + &self, + db_id: u64, + name: String, + k: usize, + uri: Option, + bytes_b64: Option, + filter: Option, + return_fields: Option>, + ) -> RpcResult; } /// RPC Server implementation @@ -1131,4 +1159,204 @@ impl RpcServer for RpcServerImpl { Ok(serde_json::json!({ "results": json_results })) } + + // ----- New image-first Lance RPC implementations ----- + + async fn lance_store_image( + &self, + db_id: u64, + name: String, + id: String, + uri: Option, + bytes_b64: Option, + meta: Option>, + ) -> RpcResult { + let server = self.get_or_create_server(db_id).await?; + if db_id == 0 { + return Err(jsonrpsee::types::ErrorObjectOwned::owned(-32000, "Lance not allowed on DB 0", None::<()>)); + } + if !matches!(server.option.backend, crate::options::BackendType::Lance) { + return Err(jsonrpsee::types::ErrorObjectOwned::owned(-32000, "DB backend is not Lance", None::<()>)); + } + if !server.has_write_permission() { + return Err(jsonrpsee::types::ErrorObjectOwned::owned(-32000, "write permission denied", None::<()>)); + } + + // Validate exactly one of uri or bytes_b64 + let (use_uri, use_b64) = (uri.is_some(), bytes_b64.is_some()); + if (use_uri && use_b64) || (!use_uri && !use_b64) { + return Err(jsonrpsee::types::ErrorObjectOwned::owned( + -32000, + "Provide exactly one of 'uri' or 'bytes_b64'", + None::<()>, + )); + } + + // Acquire image bytes (with caps) + let max_bytes: usize = std::env::var("HERODB_IMAGE_MAX_BYTES") + .ok() + .and_then(|s| s.parse::().ok()) + .unwrap_or(10 * 1024 * 1024) as usize; + + let (bytes, media_uri_opt) = if let Some(u) = uri.clone() { + let data = server + .fetch_image_bytes_from_uri(&u) + .map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?; + (data, Some(u)) + } else { + 
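+ // bytes_b64 path: decode the base64 payload and enforce the same HERODB_IMAGE_MAX_BYTES cap that fetch_image_bytes_from_uri applies to URI fetches.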
let b64 = bytes_b64.unwrap_or_default(); + let data = general_purpose::STANDARD + .decode(b64.as_bytes()) + .map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, format!("base64 decode error: {}", e), None::<()>))?; + if data.len() > max_bytes { + return Err(jsonrpsee::types::ErrorObjectOwned::owned( + -32000, + format!("Image exceeds max allowed bytes {}", max_bytes), + None::<()>, + )); + } + (data, None) + }; + + // Resolve image embedder and embed on a plain OS thread + let img_embedder = server + .get_image_embedder_for(&name) + .map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?; + let (tx, rx) = tokio::sync::oneshot::channel(); + let emb_arc = img_embedder.clone(); + let bytes_cl = bytes.clone(); + std::thread::spawn(move || { + let res = emb_arc.embed_image(&bytes_cl); + let _ = tx.send(res); + }); + let vector = match rx.await { + Ok(Ok(v)) => v, + Ok(Err(e)) => return Err(jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>)), + Err(recv_err) => { + return Err(jsonrpsee::types::ErrorObjectOwned::owned( + -32000, + format!("embedding thread error: {}", recv_err), + None::<()>, + )) + } + }; + + // Store vector with media fields + server + .lance_store() + .map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))? + .store_vector_with_media( + &name, + &id, + vector, + meta.unwrap_or_default(), + None, + Some("image".to_string()), + media_uri_opt, + ) + .await + .map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?; + + Ok(true) + } + + async fn lance_search_image( + &self, + db_id: u64, + name: String, + k: usize, + uri: Option, + bytes_b64: Option, + filter: Option, + return_fields: Option>, + ) -> RpcResult { + let server = self.get_or_create_server(db_id).await?; + if db_id == 0 { + return Err(jsonrpsee::types::ErrorObjectOwned::owned(-32000, "Lance not allowed on DB 0", None::<()>)); + } + if !matches!(server.option.backend, crate::options::BackendType::Lance) { + return Err(jsonrpsee::types::ErrorObjectOwned::owned(-32000, "DB backend is not Lance", None::<()>)); + } + if !server.has_read_permission() { + return Err(jsonrpsee::types::ErrorObjectOwned::owned(-32000, "read permission denied", None::<()>)); + } + + // Validate exactly one of uri or bytes_b64 + let (use_uri, use_b64) = (uri.is_some(), bytes_b64.is_some()); + if (use_uri && use_b64) || (!use_uri && !use_b64) { + return Err(jsonrpsee::types::ErrorObjectOwned::owned( + -32000, + "Provide exactly one of 'uri' or 'bytes_b64'", + None::<()>, + )); + } + + // Acquire image bytes for query (with caps) + let max_bytes: usize = std::env::var("HERODB_IMAGE_MAX_BYTES") + .ok() + .and_then(|s| s.parse::().ok()) + .unwrap_or(10 * 1024 * 1024) as usize; + + let bytes = if let Some(u) = uri { + server + .fetch_image_bytes_from_uri(&u) + .map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))? 
+ } else { + let b64 = bytes_b64.unwrap_or_default(); + let data = general_purpose::STANDARD + .decode(b64.as_bytes()) + .map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, format!("base64 decode error: {}", e), None::<()>))?; + if data.len() > max_bytes { + return Err(jsonrpsee::types::ErrorObjectOwned::owned( + -32000, + format!("Image exceeds max allowed bytes {}", max_bytes), + None::<()>, + )); + } + data + }; + + // Resolve image embedder and embed on OS thread + let img_embedder = server + .get_image_embedder_for(&name) + .map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?; + let (tx, rx) = tokio::sync::oneshot::channel(); + let emb_arc = img_embedder.clone(); + std::thread::spawn(move || { + let res = emb_arc.embed_image(&bytes); + let _ = tx.send(res); + }); + let qv = match rx.await { + Ok(Ok(v)) => v, + Ok(Err(e)) => return Err(jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>)), + Err(recv_err) => { + return Err(jsonrpsee::types::ErrorObjectOwned::owned( + -32000, + format!("embedding thread error: {}", recv_err), + None::<()>, + )) + } + }; + + // KNN search and return results + let results = server + .lance_store() + .map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))? + .search_vectors(&name, qv, k, filter, return_fields) + .await + .map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?; + + let json_results: Vec = results + .into_iter() + .map(|(id, score, meta)| { + serde_json::json!({ + "id": id, + "score": score, + "meta": meta, + }) + }) + .collect(); + + Ok(serde_json::json!({ "results": json_results })) + } } \ No newline at end of file diff --git a/src/server.rs b/src/server.rs index 31fe5b7..1aa8cf5 100644 --- a/src/server.rs +++ b/src/server.rs @@ -15,8 +15,11 @@ use crate::storage_trait::StorageBackend; use crate::admin_meta; // Embeddings: config and cache -use crate::embedding::{EmbeddingConfig, create_embedder, Embedder}; +use crate::embedding::{EmbeddingConfig, create_embedder, Embedder, create_image_embedder, ImageEmbedder}; use serde_json; +use ureq::{Agent, AgentBuilder}; +use std::time::Duration; +use std::io::Read; #[derive(Clone)] pub struct Server { @@ -33,9 +36,12 @@ pub struct Server { // Per-DB Lance stores (vector DB), keyed by db_id pub lance_stores: Arc>>>, - // Per-(db_id, dataset) embedder cache + // Per-(db_id, dataset) embedder cache (text) pub embedders: Arc>>>, + // Per-(db_id, dataset) image embedder cache (image) + pub image_embedders: Arc>>>, + // BLPOP waiter registry: per (db_index, key) FIFO of waiters pub list_waiters: Arc>>>>, pub waiter_seq: Arc, @@ -66,6 +72,7 @@ impl Server { search_indexes: Arc::new(std::sync::RwLock::new(HashMap::new())), lance_stores: Arc::new(std::sync::RwLock::new(HashMap::new())), embedders: Arc::new(std::sync::RwLock::new(HashMap::new())), + image_embedders: Arc::new(std::sync::RwLock::new(HashMap::new())), list_waiters: Arc::new(Mutex::new(HashMap::new())), waiter_seq: Arc::new(AtomicU64::new(1)), } @@ -189,6 +196,10 @@ impl Server { let mut map = self.embedders.write().unwrap(); map.remove(&(self.selected_db, dataset.to_string())); } + { + let mut map_img = self.image_embedders.write().unwrap(); + map_img.remove(&(self.selected_db, dataset.to_string())); + } Ok(()) } @@ -233,6 +244,88 @@ impl Server { Ok(emb) } + /// Resolve or build an IMAGE embedder for (db_id, dataset). Caches instance. 
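+ /// The provider comes from the dataset's embedding-config sidecar (get_dataset_embedding_config) and is instantiated via create_image_embedder; instances are cached per (selected_db, dataset) in image_embedders.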
+ pub fn get_image_embedder_for(&self, dataset: &str) -> Result, DBError> { + if self.selected_db == 0 { + return Err(DBError("Lance not available on admin DB 0".to_string())); + } + // Fast path + { + let map = self.image_embedders.read().unwrap(); + if let Some(e) = map.get(&(self.selected_db, dataset.to_string())) { + return Ok(e.clone()); + } + } + // Load config and instantiate + let cfg = self.get_dataset_embedding_config(dataset)?; + let emb = create_image_embedder(&cfg)?; + { + let mut map = self.image_embedders.write().unwrap(); + map.insert((self.selected_db, dataset.to_string()), emb.clone()); + } + Ok(emb) + } + + /// Download image bytes from a URI with safety checks (size, timeout, content-type, optional host allowlist). + /// Env overrides: + /// - HERODB_IMAGE_MAX_BYTES (u64, default 10485760) + /// - HERODB_IMAGE_FETCH_TIMEOUT_SECS (u64, default 30) + /// - HERODB_IMAGE_ALLOWED_HOSTS (comma-separated, optional) + pub fn fetch_image_bytes_from_uri(&self, uri: &str) -> Result, DBError> { + // Basic scheme validation + if !(uri.starts_with("http://") || uri.starts_with("https://")) { + return Err(DBError("Only http(s) URIs are supported for image fetch".into())); + } + // Parse host (naive) for allowlist check + let host = { + let after_scheme = match uri.find("://") { + Some(i) => &uri[i + 3..], + None => uri, + }; + let end = after_scheme.find('/').unwrap_or(after_scheme.len()); + let host_port = &after_scheme[..end]; + host_port.split('@').last().unwrap_or(host_port).split(':').next().unwrap_or(host_port).to_string() + }; + + let max_bytes: u64 = std::env::var("HERODB_IMAGE_MAX_BYTES").ok().and_then(|s| s.parse::().ok()).unwrap_or(10 * 1024 * 1024); + let timeout_secs: u64 = std::env::var("HERODB_IMAGE_FETCH_TIMEOUT_SECS").ok().and_then(|s| s.parse::().ok()).unwrap_or(30); + let allowed_hosts_env = std::env::var("HERODB_IMAGE_ALLOWED_HOSTS").ok(); + if let Some(allow) = allowed_hosts_env { + if !allow.split(',').map(|s| s.trim()).filter(|s| !s.is_empty()).any(|h| h.eq_ignore_ascii_case(&host)) { + return Err(DBError(format!("Host '{}' not allowed for image fetch (HERODB_IMAGE_ALLOWED_HOSTS)", host))); + } + } + + let agent: Agent = AgentBuilder::new() + .timeout_read(Duration::from_secs(timeout_secs)) + .timeout_write(Duration::from_secs(timeout_secs)) + .build(); + + let resp = agent.get(uri).call().map_err(|e| DBError(format!("HTTP GET failed: {}", e)))?; + // Validate content-type + let ctype = resp.header("Content-Type").unwrap_or(""); + let ctype_main = ctype.split(';').next().unwrap_or("").trim().to_ascii_lowercase(); + if !ctype_main.starts_with("image/") { + return Err(DBError(format!("Remote content-type '{}' is not image/*", ctype))); + } + + // Read with cap + let mut reader = resp.into_reader(); + let mut buf: Vec = Vec::with_capacity(8192); + let mut tmp = [0u8; 8192]; + let mut total: u64 = 0; + loop { + let n = reader.read(&mut tmp).map_err(|e| DBError(format!("Read error: {}", e)))?; + if n == 0 { break; } + total += n as u64; + if total > max_bytes { + return Err(DBError(format!("Image exceeds max allowed bytes {}", max_bytes))); + } + buf.extend_from_slice(&tmp[..n]); + } + Ok(buf) + } + /// Check if current permissions allow read operations pub fn has_read_permission(&self) -> bool { // If an explicit permission is set for this connection, honor it.