WIP6: implementing image embedding as first step towards multi-model support
This commit is contained in:
		
							
								
								
									
										138
									
								
								docs/lancedb_text_and_images_example.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										138
									
								
								docs/lancedb_text_and_images_example.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,138 @@ | ||||
| # LanceDB Text and Images: End-to-End Example | ||||
|  | ||||
| This guide demonstrates creating a Lance backend database, ingesting two text documents and two images, performing searches over both, and cleaning up the datasets. | ||||
|  | ||||
| Prerequisites | ||||
| - Build HeroDB and start the server with JSON-RPC enabled. | ||||
| Commands: | ||||
| ```bash | ||||
| cargo build --release | ||||
| ./target/release/herodb --dir /tmp/herodb --admin-secret mysecret --port 6379 --enable-rpc | ||||
| ``` | ||||
|  | ||||
| We'll use: | ||||
| - redis-cli for RESP commands against port 6379 | ||||
| - curl for JSON-RPC against 8080 if desired | ||||
| - Deterministic local embedders to avoid external dependencies: testhash (text, dim 64) and testimagehash (image, dim 512) | ||||
|  | ||||
| 0) Create a Lance-backed database (JSON-RPC) | ||||
| Request: | ||||
| ```json | ||||
| { "jsonrpc": "2.0", "id": 1, "method": "herodb_createDatabase", "params": ["Lance", { "name": "media-db", "storage_path": null, "max_size": null, "redis_version": null }, null] } | ||||
| ``` | ||||
| Response returns db_id (assume 1). Select DB over RESP: | ||||
| ```bash | ||||
| redis-cli -p 6379 SELECT 1 | ||||
| # → OK | ||||
| ``` | ||||
|  | ||||
| 1) Configure embedding providers | ||||
| We'll create two datasets with independent embedding configs: | ||||
| - textset → provider testhash, dim 64 | ||||
| - imageset → provider testimagehash, dim 512 | ||||
|  | ||||
| Text config: | ||||
| ```bash | ||||
| redis-cli -p 6379 LANCE.EMBEDDING CONFIG SET textset PROVIDER testhash MODEL any PARAM dim 64 | ||||
| # → OK | ||||
| ``` | ||||
| Image config: | ||||
| ```bash | ||||
| redis-cli -p 6379 LANCE.EMBEDDING CONFIG SET imageset PROVIDER testimagehash MODEL any PARAM dim 512 | ||||
| # → OK | ||||
| ``` | ||||
|  | ||||
| 2) Create datasets | ||||
| ```bash | ||||
| redis-cli -p 6379 LANCE.CREATE textset DIM 64 | ||||
| # → OK | ||||
| redis-cli -p 6379 LANCE.CREATE imageset DIM 512 | ||||
| # → OK | ||||
| ``` | ||||
|  | ||||
| 3) Ingest two text documents (server-side embedding) | ||||
| ```bash | ||||
| redis-cli -p 6379 LANCE.STORE textset ID doc-1 TEXT "The quick brown fox jumps over the lazy dog" META title "Fox" category "animal" | ||||
| # → OK | ||||
| redis-cli -p 6379 LANCE.STORE textset ID doc-2 TEXT "A fast auburn fox vaulted a sleepy canine" META title "Paraphrase" category "animal" | ||||
| # → OK | ||||
| ``` | ||||
|  | ||||
| 4) Ingest two images | ||||
| You can provide a URI or base64 bytes. Use URI for URIs, BYTES for base64 data. | ||||
| Example using free placeholder images: | ||||
| ```bash | ||||
| # Store via URI | ||||
| redis-cli -p 6379 LANCE.STOREIMAGE imageset ID img-1 URI "https://picsum.photos/seed/1/256/256" META title "Seed1" group "demo" | ||||
| # → OK | ||||
| redis-cli -p 6379 LANCE.STOREIMAGE imageset ID img-2 URI "https://picsum.photos/seed/2/256/256" META title "Seed2" group "demo" | ||||
| # → OK | ||||
| ``` | ||||
| If your environment blocks outbound HTTP, you can embed image bytes: | ||||
| ```bash | ||||
| # Example: read a local file and base64 it (replace path) | ||||
| b64=$(base64 -w0 ./image1.png) | ||||
| redis-cli -p 6379 LANCE.STOREIMAGE imageset ID img-b64-1 BYTES "$b64" META title "Local1" group "demo" | ||||
| ``` | ||||
|  | ||||
| 5) Search text | ||||
| ```bash | ||||
| # Top-2 nearest neighbors for a query | ||||
| redis-cli -p 6379 LANCE.SEARCH textset K 2 QUERY "quick brown fox" RETURN 1 title | ||||
| # → 1) [id, score, [k1,v1,...]] | ||||
| ``` | ||||
| With a filter (supports equality on schema or meta keys): | ||||
| ```bash | ||||
| redis-cli -p 6379 LANCE.SEARCH textset K 2 QUERY "fox jumps" FILTER "category = 'animal'" RETURN 1 title | ||||
| ``` | ||||
|  | ||||
| 6) Search images | ||||
| ```bash | ||||
| # Provide a URI as the query | ||||
| redis-cli -p 6379 LANCE.SEARCHIMAGE imageset K 2 QUERYURI "https://picsum.photos/seed/1/256/256" RETURN 1 title | ||||
|  | ||||
| # Or provide base64 bytes as the query | ||||
| qb64=$(curl -s https://picsum.photos/seed/3/256/256 | base64 -w0) | ||||
| redis-cli -p 6379 LANCE.SEARCHIMAGE imageset K 2 QUERYBYTES "$qb64" RETURN 1 title | ||||
| ``` | ||||
|  | ||||
| 7) Inspect datasets | ||||
| ```bash | ||||
| redis-cli -p 6379 LANCE.LIST | ||||
| redis-cli -p 6379 LANCE.INFO textset | ||||
| redis-cli -p 6379 LANCE.INFO imageset | ||||
| ``` | ||||
|  | ||||
| 8) Delete by id and drop datasets | ||||
| ```bash | ||||
| # Delete one record | ||||
| redis-cli -p 6379 LANCE.DEL textset doc-2 | ||||
| # → OK | ||||
|  | ||||
| # Drop entire datasets | ||||
| redis-cli -p 6379 LANCE.DROP textset | ||||
| redis-cli -p 6379 LANCE.DROP imageset | ||||
| # → OK | ||||
| ``` | ||||
|  | ||||
| Appendix: Using OpenAI embeddings instead of test providers | ||||
| Text: | ||||
| ```bash | ||||
| export OPENAI_API_KEY=sk-... | ||||
| redis-cli -p 6379 LANCE.EMBEDDING CONFIG SET textset PROVIDER openai MODEL text-embedding-3-small PARAM dim 512 | ||||
| redis-cli -p 6379 LANCE.CREATE textset DIM 512 | ||||
| ``` | ||||
| Azure OpenAI: | ||||
| ```bash | ||||
| export AZURE_OPENAI_API_KEY=... | ||||
| redis-cli -p 6379 LANCE.EMBEDDING CONFIG SET textset PROVIDER openai MODEL text-embedding-3-small \ | ||||
|   PARAM use_azure true \ | ||||
|   PARAM azure_endpoint https://myresource.openai.azure.com \ | ||||
|   PARAM azure_deployment my-embed-deploy \ | ||||
|   PARAM azure_api_version 2024-02-15 \ | ||||
|   PARAM dim 512 | ||||
| ``` | ||||
| Notes: | ||||
| - Ensure dataset DIM matches the configured embedding dimension. | ||||
| - Lance is only available for non-admin databases (db_id >= 1). | ||||
| - On Lance DBs, only LANCE.* and basic control commands are allowed. | ||||
							
								
								
									
										259
									
								
								src/cmd.rs
									
									
									
									
									
								
							
							
						
						
									
										259
									
								
								src/cmd.rs
									
									
									
									
									
								
							| @@ -1,4 +1,5 @@ | ||||
| use crate::{error::DBError, protocol::Protocol, server::Server, embedding::{EmbeddingConfig, EmbeddingProvider}}; | ||||
| use base64::{engine::general_purpose, Engine as _}; | ||||
| use tokio::time::{timeout, Duration}; | ||||
| use futures::future::select_all; | ||||
|  | ||||
| @@ -145,6 +146,22 @@ pub enum Cmd { | ||||
|         filter: Option<String>, | ||||
|         return_fields: Option<Vec<String>>, | ||||
|     }, | ||||
|     // Image-first commands (no user-provided vectors) | ||||
|     LanceStoreImage { | ||||
|         name: String, | ||||
|         id: String, | ||||
|         uri: Option<String>, | ||||
|         bytes_b64: Option<String>, | ||||
|         meta: Vec<(String, String)>, | ||||
|     }, | ||||
|     LanceSearchImage { | ||||
|         name: String, | ||||
|         k: usize, | ||||
|         uri: Option<String>, | ||||
|         bytes_b64: Option<String>, | ||||
|         filter: Option<String>, | ||||
|         return_fields: Option<Vec<String>>, | ||||
|     }, | ||||
|     LanceCreateIndex { | ||||
|         name: String, | ||||
|         index_type: String, | ||||
| @@ -903,6 +920,46 @@ impl Cmd { | ||||
|                             } | ||||
|                             Cmd::LanceStoreText { name, id, text, meta } | ||||
|                         } | ||||
|                         "lance.storeimage" => { | ||||
|                             // LANCE.STOREIMAGE name ID <id> (URI <uri> | BYTES <base64>) [META k v ...] | ||||
|                             if cmd.len() < 6 { | ||||
|                                 return Err(DBError("ERR LANCE.STOREIMAGE requires: name ID <id> (URI <uri> | BYTES <base64>) [META k v ...]".to_string())); | ||||
|                             } | ||||
|                             let name = cmd[1].clone(); | ||||
|                             let mut i = 2; | ||||
|                             if cmd[i].to_uppercase() != "ID" || i + 1 >= cmd.len() { | ||||
|                                 return Err(DBError("ERR LANCE.STOREIMAGE requires ID <id>".to_string())); | ||||
|                             } | ||||
|                             let id = cmd[i + 1].clone(); | ||||
|                             i += 2; | ||||
|  | ||||
|                             let mut uri_opt: Option<String> = None; | ||||
|                             let mut bytes_b64_opt: Option<String> = None; | ||||
|  | ||||
|                             if i < cmd.len() && cmd[i].to_uppercase() == "URI" { | ||||
|                                 if i + 1 >= cmd.len() { return Err(DBError("ERR LANCE.STOREIMAGE URI requires a value".to_string())); } | ||||
|                                 uri_opt = Some(cmd[i + 1].clone()); | ||||
|                                 i += 2; | ||||
|                             } else if i < cmd.len() && cmd[i].to_uppercase() == "BYTES" { | ||||
|                                 if i + 1 >= cmd.len() { return Err(DBError("ERR LANCE.STOREIMAGE BYTES requires a value".to_string())); } | ||||
|                                 bytes_b64_opt = Some(cmd[i + 1].clone()); | ||||
|                                 i += 2; | ||||
|                             } else { | ||||
|                                 return Err(DBError("ERR LANCE.STOREIMAGE requires either URI <uri> or BYTES <base64>".to_string())); | ||||
|                             } | ||||
|  | ||||
|                             // Parse optional META pairs | ||||
|                             let mut meta: Vec<(String, String)> = Vec::new(); | ||||
|                             if i < cmd.len() && cmd[i].to_uppercase() == "META" { | ||||
|                                 i += 1; | ||||
|                                 while i + 1 < cmd.len() { | ||||
|                                     meta.push((cmd[i].clone(), cmd[i + 1].clone())); | ||||
|                                     i += 2; | ||||
|                                 } | ||||
|                             } | ||||
|  | ||||
|                             Cmd::LanceStoreImage { name, id, uri: uri_opt, bytes_b64: bytes_b64_opt, meta } | ||||
|                         } | ||||
|                         "lance.search" => { | ||||
|                             // LANCE.SEARCH name K <k> QUERY <text> [FILTER expr] [RETURN n fields...] | ||||
|                             if cmd.len() < 6 { | ||||
| @@ -954,6 +1011,65 @@ impl Cmd { | ||||
|                             } | ||||
|                             Cmd::LanceSearchText { name, text, k, filter, return_fields } | ||||
|                         } | ||||
|                         "lance.searchimage" => { | ||||
|                             // LANCE.SEARCHIMAGE name K <k> (QUERYURI <uri> | QUERYBYTES <base64>) [FILTER expr] [RETURN n fields...] | ||||
|                             if cmd.len() < 6 { | ||||
|                                 return Err(DBError("ERR LANCE.SEARCHIMAGE requires: name K <k> (QUERYURI <uri> | QUERYBYTES <base64>) [FILTER expr] [RETURN n fields...]".to_string())); | ||||
|                             } | ||||
|                             let name = cmd[1].clone(); | ||||
|                             if cmd[2].to_uppercase() != "K" { | ||||
|                                 return Err(DBError("ERR LANCE.SEARCHIMAGE requires K <k>".to_string())); | ||||
|                             } | ||||
|                             let k: usize = cmd[3].parse().map_err(|_| DBError("ERR K must be an integer".to_string()))?; | ||||
|                             let mut i = 4; | ||||
|  | ||||
|                             let mut uri_opt: Option<String> = None; | ||||
|                             let mut bytes_b64_opt: Option<String> = None; | ||||
|  | ||||
|                             if i < cmd.len() && cmd[i].to_uppercase() == "QUERYURI" { | ||||
|                                 if i + 1 >= cmd.len() { return Err(DBError("ERR QUERYURI requires a value".to_string())); } | ||||
|                                 uri_opt = Some(cmd[i + 1].clone()); | ||||
|                                 i += 2; | ||||
|                             } else if i < cmd.len() && cmd[i].to_uppercase() == "QUERYBYTES" { | ||||
|                                 if i + 1 >= cmd.len() { return Err(DBError("ERR QUERYBYTES requires a value".to_string())); } | ||||
|                                 bytes_b64_opt = Some(cmd[i + 1].clone()); | ||||
|                                 i += 2; | ||||
|                             } else { | ||||
|                                 return Err(DBError("ERR LANCE.SEARCHIMAGE requires QUERYURI <uri> or QUERYBYTES <base64>".to_string())); | ||||
|                             } | ||||
|  | ||||
|                             let mut filter: Option<String> = None; | ||||
|                             let mut return_fields: Option<Vec<String>> = None; | ||||
|                             while i < cmd.len() { | ||||
|                                 match cmd[i].to_uppercase().as_str() { | ||||
|                                     "FILTER" => { | ||||
|                                         if i + 1 >= cmd.len() { | ||||
|                                             return Err(DBError("ERR FILTER requires an expression".to_string())); | ||||
|                                         } | ||||
|                                         filter = Some(cmd[i + 1].clone()); | ||||
|                                         i += 2; | ||||
|                                     } | ||||
|                                     "RETURN" => { | ||||
|                                         if i + 1 >= cmd.len() { | ||||
|                                             return Err(DBError("ERR RETURN requires field count".to_string())); | ||||
|                                         } | ||||
|                                         let n: usize = cmd[i + 1].parse().map_err(|_| DBError("ERR RETURN count must be integer".to_string()))?; | ||||
|                                         i += 2; | ||||
|                                         let mut fields = Vec::new(); | ||||
|                                         for _ in 0..n { | ||||
|                                             if i < cmd.len() { | ||||
|                                                 fields.push(cmd[i].clone()); | ||||
|                                                 i += 1; | ||||
|                                             } | ||||
|                                         } | ||||
|                                         return_fields = Some(fields); | ||||
|                                     } | ||||
|                                     _ => { i += 1; } | ||||
|                                 } | ||||
|                             } | ||||
|  | ||||
|                             Cmd::LanceSearchImage { name, k, uri: uri_opt, bytes_b64: bytes_b64_opt, filter, return_fields } | ||||
|                         } | ||||
|                         "lance.createindex" => { | ||||
|                             // LANCE.CREATEINDEX name TYPE t [PARAM k v ...] | ||||
|                             if cmd.len() < 4 || cmd[2].to_uppercase() != "TYPE" { | ||||
| @@ -1136,6 +1252,8 @@ impl Cmd { | ||||
|                 | Cmd::LanceCreate { .. } | ||||
|                 | Cmd::LanceStoreText { .. } | ||||
|                 | Cmd::LanceSearchText { .. } | ||||
|                 | Cmd::LanceStoreImage { .. } | ||||
|                 | Cmd::LanceSearchImage { .. } | ||||
|                 | Cmd::LanceEmbeddingConfigSet { .. } | ||||
|                 | Cmd::LanceEmbeddingConfigGet { .. } | ||||
|                 | Cmd::LanceCreateIndex { .. } | ||||
| @@ -1172,6 +1290,8 @@ impl Cmd { | ||||
|                 Cmd::LanceCreate { .. } | ||||
|                 | Cmd::LanceStoreText { .. } | ||||
|                 | Cmd::LanceSearchText { .. } | ||||
|                 | Cmd::LanceStoreImage { .. } | ||||
|                 | Cmd::LanceSearchImage { .. } | ||||
|                 | Cmd::LanceEmbeddingConfigSet { .. } | ||||
|                 | Cmd::LanceEmbeddingConfigGet { .. } | ||||
|                 | Cmd::LanceCreateIndex { .. } | ||||
| @@ -1421,6 +1541,145 @@ impl Cmd { | ||||
|                     Err(e) => Ok(Protocol::err(&e.0)), | ||||
|                 } | ||||
|             } | ||||
|  | ||||
|             // New: Image store | ||||
|             Cmd::LanceStoreImage { name, id, uri, bytes_b64, meta } => { | ||||
|                 if !server.has_write_permission() { | ||||
|                     return Ok(Protocol::err("ERR write permission denied")); | ||||
|                 } | ||||
|                 let use_uri = uri.is_some(); | ||||
|                 let use_b64 = bytes_b64.is_some(); | ||||
|                 if (use_uri && use_b64) || (!use_uri && !use_b64) { | ||||
|                     return Ok(Protocol::err("ERR Provide exactly one of URI or BYTES for LANCE.STOREIMAGE")); | ||||
|                 } | ||||
|                 let max_bytes: usize = std::env::var("HERODB_IMAGE_MAX_BYTES") | ||||
|                     .ok() | ||||
|                     .and_then(|s| s.parse::<u64>().ok()) | ||||
|                     .unwrap_or(10 * 1024 * 1024) as usize; | ||||
|  | ||||
|                 let media_uri_opt = if let Some(u) = uri.clone() { | ||||
|                     match server.fetch_image_bytes_from_uri(&u) { | ||||
|                         Ok(_) => {} | ||||
|                         Err(e) => return Ok(Protocol::err(&e.0)), | ||||
|                     } | ||||
|                     Some(u) | ||||
|                 } else { | ||||
|                     None | ||||
|                 }; | ||||
|  | ||||
|                 let bytes: Vec<u8> = if let Some(u) = uri { | ||||
|                     match server.fetch_image_bytes_from_uri(&u) { | ||||
|                         Ok(b) => b, | ||||
|                         Err(e) => return Ok(Protocol::err(&e.0)), | ||||
|                     } | ||||
|                 } else { | ||||
|                     let b64 = bytes_b64.unwrap_or_default(); | ||||
|                     let data = match general_purpose::STANDARD.decode(b64.as_bytes()) { | ||||
|                         Ok(d) => d, | ||||
|                         Err(e) => return Ok(Protocol::err(&format!("ERR base64 decode error: {}", e))), | ||||
|                     }; | ||||
|                     if data.len() > max_bytes { | ||||
|                         return Ok(Protocol::err(&format!("ERR image exceeds max allowed bytes {}", max_bytes))); | ||||
|                     } | ||||
|                     data | ||||
|                 }; | ||||
|  | ||||
|                 let img_embedder = match server.get_image_embedder_for(&name) { | ||||
|                     Ok(e) => e, | ||||
|                     Err(e) => return Ok(Protocol::err(&e.0)), | ||||
|                 }; | ||||
|                 let (tx, rx) = tokio::sync::oneshot::channel(); | ||||
|                 let emb_arc = img_embedder.clone(); | ||||
|                 let bytes_cl = bytes.clone(); | ||||
|                 std::thread::spawn(move || { | ||||
|                     let res = emb_arc.embed_image(&bytes_cl); | ||||
|                     let _ = tx.send(res); | ||||
|                 }); | ||||
|                 let vector = match rx.await { | ||||
|                     Ok(Ok(v)) => v, | ||||
|                     Ok(Err(e)) => return Ok(Protocol::err(&e.0)), | ||||
|                     Err(recv_err) => return Ok(Protocol::err(&format!("ERR embedding thread error: {}", recv_err))), | ||||
|                 }; | ||||
|  | ||||
|                 let meta_map: std::collections::HashMap<String, String> = meta.into_iter().collect(); | ||||
|                 match server.lance_store()?.store_vector_with_media( | ||||
|                     &name, | ||||
|                     &id, | ||||
|                     vector, | ||||
|                     meta_map, | ||||
|                     None, | ||||
|                     Some("image".to_string()), | ||||
|                     media_uri_opt, | ||||
|                 ).await { | ||||
|                     Ok(()) => Ok(Protocol::SimpleString("OK".to_string())), | ||||
|                     Err(e) => Ok(Protocol::err(&e.0)), | ||||
|                 } | ||||
|             } | ||||
|  | ||||
|             // New: Image search | ||||
|             Cmd::LanceSearchImage { name, k, uri, bytes_b64, filter, return_fields } => { | ||||
|                 let use_uri = uri.is_some(); | ||||
|                 let use_b64 = bytes_b64.is_some(); | ||||
|                 if (use_uri && use_b64) || (!use_uri && !use_b64) { | ||||
|                     return Ok(Protocol::err("ERR Provide exactly one of QUERYURI or QUERYBYTES for LANCE.SEARCHIMAGE")); | ||||
|                 } | ||||
|                 let max_bytes: usize = std::env::var("HERODB_IMAGE_MAX_BYTES") | ||||
|                     .ok() | ||||
|                     .and_then(|s| s.parse::<u64>().ok()) | ||||
|                     .unwrap_or(10 * 1024 * 1024) as usize; | ||||
|  | ||||
|                 let bytes: Vec<u8> = if let Some(u) = uri { | ||||
|                     match server.fetch_image_bytes_from_uri(&u) { | ||||
|                         Ok(b) => b, | ||||
|                         Err(e) => return Ok(Protocol::err(&e.0)), | ||||
|                     } | ||||
|                 } else { | ||||
|                     let b64 = bytes_b64.unwrap_or_default(); | ||||
|                     let data = match general_purpose::STANDARD.decode(b64.as_bytes()) { | ||||
|                         Ok(d) => d, | ||||
|                         Err(e) => return Ok(Protocol::err(&format!("ERR base64 decode error: {}", e))), | ||||
|                     }; | ||||
|                     if data.len() > max_bytes { | ||||
|                         return Ok(Protocol::err(&format!("ERR image exceeds max allowed bytes {}", max_bytes))); | ||||
|                     } | ||||
|                     data | ||||
|                 }; | ||||
|  | ||||
|                 let img_embedder = match server.get_image_embedder_for(&name) { | ||||
|                     Ok(e) => e, | ||||
|                     Err(e) => return Ok(Protocol::err(&e.0)), | ||||
|                 }; | ||||
|                 let (tx, rx) = tokio::sync::oneshot::channel(); | ||||
|                 std::thread::spawn(move || { | ||||
|                     let res = img_embedder.embed_image(&bytes); | ||||
|                     let _ = tx.send(res); | ||||
|                 }); | ||||
|                 let qv = match rx.await { | ||||
|                     Ok(Ok(v)) => v, | ||||
|                     Ok(Err(e)) => return Ok(Protocol::err(&e.0)), | ||||
|                     Err(recv_err) => return Ok(Protocol::err(&format!("ERR embedding thread error: {}", recv_err))), | ||||
|                 }; | ||||
|  | ||||
|                 match server.lance_store()?.search_vectors(&name, qv, k, filter, return_fields).await { | ||||
|                     Ok(results) => { | ||||
|                         let mut arr = Vec::new(); | ||||
|                         for (id, score, meta) in results { | ||||
|                             let mut meta_arr: Vec<Protocol> = Vec::new(); | ||||
|                             for (k, v) in meta { | ||||
|                                 meta_arr.push(Protocol::BulkString(k)); | ||||
|                                 meta_arr.push(Protocol::BulkString(v)); | ||||
|                             } | ||||
|                             arr.push(Protocol::Array(vec![ | ||||
|                                 Protocol::BulkString(id), | ||||
|                                 Protocol::BulkString(score.to_string()), | ||||
|                                 Protocol::Array(meta_arr), | ||||
|                             ])); | ||||
|                         } | ||||
|                         Ok(Protocol::Array(arr)) | ||||
|                     } | ||||
|                     Err(e) => Ok(Protocol::err(&e.0)), | ||||
|                 } | ||||
|             } | ||||
|             Cmd::LanceCreateIndex { name, index_type, params } => { | ||||
|                 if !server.has_write_permission() { | ||||
|                     return Ok(Protocol::err("ERR write permission denied")); | ||||
|   | ||||
| @@ -1,8 +1,8 @@ | ||||
| use chacha20poly1305::{ | ||||
|     aead::{Aead, KeyInit, OsRng}, | ||||
|     aead::{Aead, KeyInit}, | ||||
|     XChaCha20Poly1305, XNonce, | ||||
| }; | ||||
| use rand::RngCore; | ||||
| use rand::{rngs::OsRng, RngCore}; | ||||
| use sha2::{Digest, Sha256}; | ||||
|  | ||||
| const VERSION: u8 = 1; | ||||
| @@ -31,7 +31,7 @@ pub struct CryptoFactory { | ||||
| impl CryptoFactory { | ||||
|     /// Accepts any secret bytes; turns them into a 32-byte key (SHA-256). | ||||
|     pub fn new<S: AsRef<[u8]>>(secret: S) -> Self { | ||||
|         let mut h = Sha256::new(); | ||||
|         let mut h = Sha256::default(); | ||||
|         h.update(b"xchacha20poly1305-factory:v1"); // domain separation | ||||
|         h.update(secret.as_ref()); | ||||
|         let digest = h.finalize(); // 32 bytes | ||||
|   | ||||
| @@ -1,18 +1,4 @@ | ||||
| // Embedding abstraction and minimal providers. | ||||
| // | ||||
| // This module defines a provider-agnostic interface to produce vector embeddings | ||||
| // from text, so callers never need to supply vectors manually. It includes: | ||||
| // - Embedder trait | ||||
| // - EmbeddingProvider and EmbeddingConfig (serde-serializable) | ||||
| // - TestHashEmbedder: deterministic, CPU-only, no-network embedder suitable for CI | ||||
| // - Factory create_embedder(..) to instantiate an embedder from config | ||||
| // | ||||
| // Integration plan: | ||||
| // - Server will resolve per-dataset EmbeddingConfig from sidecar JSON and cache Arc<dyn Embedder> | ||||
| // - LanceStore will call embedder.embed(text) then persist id, vector, text, meta | ||||
| // | ||||
| // Note: Real LanceDB-backed embedding providers can be added by implementing Embedder | ||||
| // and extending create_embedder(..). This file keeps no direct dependency on LanceDB. | ||||
|  | ||||
| use std::collections::HashMap; | ||||
| use std::sync::Arc; | ||||
|   | ||||
| @@ -112,6 +112,8 @@ impl LanceStore { | ||||
|             Field::new("id", DataType::Utf8, false), | ||||
|             Self::vector_field(dim), | ||||
|             Field::new("text", DataType::Utf8, true), | ||||
|             Field::new("media_type", DataType::Utf8, true), | ||||
|             Field::new("media_uri", DataType::Utf8, true), | ||||
|             Field::new("meta", DataType::Utf8, true), | ||||
|         ])) | ||||
|     } | ||||
| @@ -121,6 +123,8 @@ impl LanceStore { | ||||
|         vector: &[f32], | ||||
|         meta: &HashMap<String, String>, | ||||
|         text: Option<&str>, | ||||
|         media_type: Option<&str>, | ||||
|         media_uri: Option<&str>, | ||||
|         dim: i32, | ||||
|     ) -> Result<(Arc<Schema>, RecordBatch), DBError> { | ||||
|         if vector.len() as i32 != dim { | ||||
| @@ -156,6 +160,24 @@ impl LanceStore { | ||||
|         } | ||||
|         let text_arr = Arc::new(text_builder.finish()) as Arc<dyn Array>; | ||||
|  | ||||
|         // media_type column (optional) | ||||
|         let mut mt_builder = StringBuilder::new(); | ||||
|         if let Some(mt) = media_type { | ||||
|             mt_builder.append_value(mt); | ||||
|         } else { | ||||
|             mt_builder.append_null(); | ||||
|         } | ||||
|         let mt_arr = Arc::new(mt_builder.finish()) as Arc<dyn Array>; | ||||
|  | ||||
|         // media_uri column (optional) | ||||
|         let mut mu_builder = StringBuilder::new(); | ||||
|         if let Some(mu) = media_uri { | ||||
|             mu_builder.append_value(mu); | ||||
|         } else { | ||||
|             mu_builder.append_null(); | ||||
|         } | ||||
|         let mu_arr = Arc::new(mu_builder.finish()) as Arc<dyn Array>; | ||||
|  | ||||
|         // meta column (JSON string) | ||||
|         let meta_json = if meta.is_empty() { | ||||
|             None | ||||
| @@ -171,7 +193,7 @@ impl LanceStore { | ||||
|         let meta_arr = Arc::new(meta_builder.finish()) as Arc<dyn Array>; | ||||
|  | ||||
|         let batch = | ||||
|             RecordBatch::try_new(schema.clone(), vec![id_arr, vec_arr, text_arr, meta_arr]).map_err(|e| { | ||||
|             RecordBatch::try_new(schema.clone(), vec![id_arr, vec_arr, text_arr, mt_arr, mu_arr, meta_arr]).map_err(|e| { | ||||
|                 DBError(format!("RecordBatch build failed: {e}")) | ||||
|             })?; | ||||
|  | ||||
| @@ -207,10 +229,12 @@ impl LanceStore { | ||||
|         let mut list_builder = FixedSizeListBuilder::new(v_builder, dim_i32); | ||||
|         let empty_vec = Arc::new(list_builder.finish()) as Arc<dyn Array>; | ||||
|         let empty_text = Arc::new(StringArray::new_null(0)); | ||||
|         let empty_media_type = Arc::new(StringArray::new_null(0)); | ||||
|         let empty_media_uri = Arc::new(StringArray::new_null(0)); | ||||
|         let empty_meta = Arc::new(StringArray::new_null(0)); | ||||
|  | ||||
|         let empty_batch = | ||||
|             RecordBatch::try_new(schema.clone(), vec![empty_id, empty_vec, empty_text, empty_meta]) | ||||
|             RecordBatch::try_new(schema.clone(), vec![empty_id, empty_vec, empty_text, empty_media_type, empty_media_uri, empty_meta]) | ||||
|                 .map_err(|e| DBError(format!("Build empty batch failed: {e}")))?; | ||||
|  | ||||
|         let write_params = WriteParams { | ||||
| @@ -235,6 +259,21 @@ impl LanceStore { | ||||
|         vector: Vec<f32>, | ||||
|         meta: HashMap<String, String>, | ||||
|         text: Option<String>, | ||||
|     ) -> Result<(), DBError> { | ||||
|         // Delegate to media-aware path with no media fields | ||||
|         self.store_vector_with_media(name, id, vector, meta, text, None, None).await | ||||
|     } | ||||
|  | ||||
|     /// Store/Upsert a single vector with optional text and media fields (media_type/media_uri). | ||||
|     pub async fn store_vector_with_media( | ||||
|         &self, | ||||
|         name: &str, | ||||
|         id: &str, | ||||
|         vector: Vec<f32>, | ||||
|         meta: HashMap<String, String>, | ||||
|         text: Option<String>, | ||||
|         media_type: Option<String>, | ||||
|         media_uri: Option<String>, | ||||
|     ) -> Result<(), DBError> { | ||||
|         let path = self.dataset_path(name); | ||||
|  | ||||
| @@ -248,7 +287,15 @@ impl LanceStore { | ||||
|                 .map_err(|_| DBError("Vector length too large".into()))? | ||||
|         }; | ||||
|  | ||||
|         let (schema, batch) = Self::build_one_row_batch(id, &vector, &meta, text.as_deref(), dim_i32)?; | ||||
|         let (schema, batch) = Self::build_one_row_batch( | ||||
|             id, | ||||
|             &vector, | ||||
|             &meta, | ||||
|             text.as_deref(), | ||||
|             media_type.as_deref(), | ||||
|             media_uri.as_deref(), | ||||
|             dim_i32, | ||||
|         )?; | ||||
|  | ||||
|         // If LanceDB table exists and provides delete, we can upsert by deleting same id | ||||
|         // Try best-effort delete; ignore errors to keep operation append-only on failure | ||||
| @@ -355,21 +402,36 @@ impl LanceStore { | ||||
|             .await | ||||
|             .map_err(|e| DBError(format!("Open dataset failed: {}", e)))?; | ||||
|  | ||||
|         // Build scanner with projection; filter if provided | ||||
|         // Build scanner with projection; we project needed fields and filter client-side to support meta keys | ||||
|         let mut scan = ds.scan(); | ||||
|         if let Err(e) = scan.project(&["id", "vector", "meta"]) { | ||||
|         if let Err(e) = scan.project(&["id", "vector", "meta", "text", "media_type", "media_uri"]) { | ||||
|             return Err(DBError(format!("Project failed: {}", e))); | ||||
|         } | ||||
|         if let Some(pred) = filter { | ||||
|             if let Err(e) = scan.filter(&pred) { | ||||
|                 return Err(DBError(format!("Filter failed: {}", e))); | ||||
|             } | ||||
|         } | ||||
|         // Note: we no longer push down filter to Lance to allow filtering on meta fields client-side. | ||||
|  | ||||
|         let mut stream = scan | ||||
|             .try_into_stream() | ||||
|             .await | ||||
|             .map_err(|e| DBError(format!("Scan stream failed: {}", e)))?; | ||||
|          | ||||
|         // Parse simple equality clause from filter for client-side filtering (supports one `key = 'value'`) | ||||
|         let clause = filter.as_ref().and_then(|s| { | ||||
|             fn parse_eq(s: &str) -> Option<(String, String)> { | ||||
|                 let s = s.trim(); | ||||
|                 let pos = s.find('=').or_else(|| s.find(" = "))?; | ||||
|                 let (k, vraw) = s.split_at(pos); | ||||
|                 let mut v = vraw.trim_start_matches('=').trim(); | ||||
|                 if (v.starts_with('\'') && v.ends_with('\'')) || (v.starts_with('"') && v.ends_with('"')) { | ||||
|                     if v.len() >= 2 { | ||||
|                         v = &v[1..v.len()-1]; | ||||
|                     } | ||||
|                 } | ||||
|                 let key = k.trim().trim_matches('"').trim_matches('\'').to_string(); | ||||
|                 if key.is_empty() { return None; } | ||||
|                 Some((key, v.to_string())) | ||||
|             } | ||||
|             parse_eq(s) | ||||
|         }); | ||||
|  | ||||
|         // Maintain a max-heap with reverse ordering to keep top-k smallest distances | ||||
|         #[derive(Debug)] | ||||
| @@ -412,20 +474,18 @@ impl LanceStore { | ||||
|             let meta_arr = batch | ||||
|                 .column_by_name("meta") | ||||
|                 .map(|a| a.as_string::<i32>().clone()); | ||||
|             let text_arr = batch | ||||
|                 .column_by_name("text") | ||||
|                 .map(|a| a.as_string::<i32>().clone()); | ||||
|             let mt_arr = batch | ||||
|                 .column_by_name("media_type") | ||||
|                 .map(|a| a.as_string::<i32>().clone()); | ||||
|             let mu_arr = batch | ||||
|                 .column_by_name("media_uri") | ||||
|                 .map(|a| a.as_string::<i32>().clone()); | ||||
|  | ||||
|             for i in 0..batch.num_rows() { | ||||
|                 // Compute L2 distance | ||||
|                 let val = vec_arr.value(i); | ||||
|                 let prim = val.as_primitive::<Float32Type>(); | ||||
|                 let mut dist: f32 = 0.0; | ||||
|                 let plen = prim.len(); | ||||
|                 for j in 0..plen { | ||||
|                     let r = prim.value(j); | ||||
|                     let d = query[j] - r; | ||||
|                     dist += d * d; | ||||
|                 } | ||||
|  | ||||
|                 // Parse id | ||||
|                 // Extract id | ||||
|                 let id_val = id_arr.value(i).to_string(); | ||||
|  | ||||
|                 // Parse meta JSON if present | ||||
| @@ -439,26 +499,54 @@ impl LanceStore { | ||||
|                                     meta.insert(k, vs.to_string()); | ||||
|                                 } else if v.is_number() || v.is_boolean() { | ||||
|                                     meta.insert(k, v.to_string()); | ||||
|                                 } else { | ||||
|                                     // skip complex entries | ||||
|                                 } | ||||
|                             } | ||||
|                         } | ||||
|                     } | ||||
|                 } | ||||
|  | ||||
|                 // Evaluate simple equality filter if provided (supports one clause) | ||||
|                 let passes = if let Some((ref key, ref val)) = clause { | ||||
|                     let candidate = match key.as_str() { | ||||
|                         "id" => Some(id_val.clone()), | ||||
|                         "text" => text_arr.as_ref().and_then(|col| if col.is_null(i) { None } else { Some(col.value(i).to_string()) }), | ||||
|                         "media_type" => mt_arr.as_ref().and_then(|col| if col.is_null(i) { None } else { Some(col.value(i).to_string()) }), | ||||
|                         "media_uri" => mu_arr.as_ref().and_then(|col| if col.is_null(i) { None } else { Some(col.value(i).to_string()) }), | ||||
|                         _ => meta.get(key).cloned(), | ||||
|                     }; | ||||
|                     match candidate { | ||||
|                         Some(cv) => cv == *val, | ||||
|                         None => false, | ||||
|                     } | ||||
|                 } else { true }; | ||||
|                 if !passes { | ||||
|                     continue; | ||||
|                 } | ||||
|  | ||||
|                 // Compute L2 distance | ||||
|                 let val = vec_arr.value(i); | ||||
|                 let prim = val.as_primitive::<Float32Type>(); | ||||
|                 let mut dist: f32 = 0.0; | ||||
|                 let plen = prim.len(); | ||||
|                 for j in 0..plen { | ||||
|                     let r = prim.value(j); | ||||
|                     let d = query[j] - r; | ||||
|                     dist += d * d; | ||||
|                 } | ||||
|  | ||||
|                 // Apply return_fields on meta | ||||
|                 let mut meta_out = meta; | ||||
|                 if let Some(fields) = &return_fields { | ||||
|                     let mut filtered = HashMap::new(); | ||||
|                     for f in fields { | ||||
|                         if let Some(val) = meta.get(f) { | ||||
|                         if let Some(val) = meta_out.get(f) { | ||||
|                             filtered.insert(f.clone(), val.clone()); | ||||
|                         } | ||||
|                     } | ||||
|                     meta = filtered; | ||||
|                     meta_out = filtered; | ||||
|                 } | ||||
|  | ||||
|                 let hit = Hit { dist, id: id_val, meta }; | ||||
|                 let hit = Hit { dist, id: id_val, meta: meta_out }; | ||||
|  | ||||
|                 if heap.len() < k { | ||||
|                     heap.push(hit); | ||||
|   | ||||
							
								
								
									
										228
									
								
								src/rpc.rs
									
									
									
									
									
								
							
							
						
						
									
										228
									
								
								src/rpc.rs
									
									
									
									
									
								
							| @@ -10,6 +10,7 @@ use crate::server::Server; | ||||
| use crate::options::DBOption; | ||||
| use crate::admin_meta; | ||||
| use crate::embedding::{EmbeddingConfig, EmbeddingProvider}; | ||||
| use base64::{engine::general_purpose, Engine as _}; | ||||
|  | ||||
| /// Database backend types | ||||
| #[derive(Debug, Clone, Serialize, Deserialize)] | ||||
| @@ -282,6 +283,33 @@ pub trait Rpc { | ||||
|         filter: Option<String>, | ||||
|         return_fields: Option<Vec<String>>, | ||||
|     ) -> RpcResult<serde_json::Value>; | ||||
|  | ||||
|     // ----- Image-first endpoints (no user-provided vectors) ----- | ||||
|  | ||||
|     /// Store an image; exactly one of uri or bytes_b64 must be provided. | ||||
|     #[method(name = "lanceStoreImage")] | ||||
|     async fn lance_store_image( | ||||
|         &self, | ||||
|         db_id: u64, | ||||
|         name: String, | ||||
|         id: String, | ||||
|         uri: Option<String>, | ||||
|         bytes_b64: Option<String>, | ||||
|         meta: Option<HashMap<String, String>>, | ||||
|     ) -> RpcResult<bool>; | ||||
|  | ||||
|     /// Search using an image query; exactly one of uri or bytes_b64 must be provided. | ||||
|     #[method(name = "lanceSearchImage")] | ||||
|     async fn lance_search_image( | ||||
|         &self, | ||||
|         db_id: u64, | ||||
|         name: String, | ||||
|         k: usize, | ||||
|         uri: Option<String>, | ||||
|         bytes_b64: Option<String>, | ||||
|         filter: Option<String>, | ||||
|         return_fields: Option<Vec<String>>, | ||||
|     ) -> RpcResult<serde_json::Value>; | ||||
| } | ||||
|  | ||||
| /// RPC Server implementation | ||||
| @@ -1131,4 +1159,204 @@ impl RpcServer for RpcServerImpl { | ||||
|  | ||||
|         Ok(serde_json::json!({ "results": json_results })) | ||||
|     } | ||||
|  | ||||
|     // ----- New image-first Lance RPC implementations ----- | ||||
|  | ||||
|     async fn lance_store_image( | ||||
|         &self, | ||||
|         db_id: u64, | ||||
|         name: String, | ||||
|         id: String, | ||||
|         uri: Option<String>, | ||||
|         bytes_b64: Option<String>, | ||||
|         meta: Option<HashMap<String, String>>, | ||||
|     ) -> RpcResult<bool> { | ||||
|         let server = self.get_or_create_server(db_id).await?; | ||||
|         if db_id == 0 { | ||||
|             return Err(jsonrpsee::types::ErrorObjectOwned::owned(-32000, "Lance not allowed on DB 0", None::<()>)); | ||||
|         } | ||||
|         if !matches!(server.option.backend, crate::options::BackendType::Lance) { | ||||
|             return Err(jsonrpsee::types::ErrorObjectOwned::owned(-32000, "DB backend is not Lance", None::<()>)); | ||||
|         } | ||||
|         if !server.has_write_permission() { | ||||
|             return Err(jsonrpsee::types::ErrorObjectOwned::owned(-32000, "write permission denied", None::<()>)); | ||||
|         } | ||||
|  | ||||
|         // Validate exactly one of uri or bytes_b64 | ||||
|         let (use_uri, use_b64) = (uri.is_some(), bytes_b64.is_some()); | ||||
|         if (use_uri && use_b64) || (!use_uri && !use_b64) { | ||||
|             return Err(jsonrpsee::types::ErrorObjectOwned::owned( | ||||
|                 -32000, | ||||
|                 "Provide exactly one of 'uri' or 'bytes_b64'", | ||||
|                 None::<()>, | ||||
|             )); | ||||
|         } | ||||
|  | ||||
|         // Acquire image bytes (with caps) | ||||
|         let max_bytes: usize = std::env::var("HERODB_IMAGE_MAX_BYTES") | ||||
|             .ok() | ||||
|             .and_then(|s| s.parse::<u64>().ok()) | ||||
|             .unwrap_or(10 * 1024 * 1024) as usize; | ||||
|  | ||||
|         let (bytes, media_uri_opt) = if let Some(u) = uri.clone() { | ||||
|             let data = server | ||||
|                 .fetch_image_bytes_from_uri(&u) | ||||
|                 .map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?; | ||||
|             (data, Some(u)) | ||||
|         } else { | ||||
|             let b64 = bytes_b64.unwrap_or_default(); | ||||
|             let data = general_purpose::STANDARD | ||||
|                 .decode(b64.as_bytes()) | ||||
|                 .map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, format!("base64 decode error: {}", e), None::<()>))?; | ||||
|             if data.len() > max_bytes { | ||||
|                 return Err(jsonrpsee::types::ErrorObjectOwned::owned( | ||||
|                     -32000, | ||||
|                     format!("Image exceeds max allowed bytes {}", max_bytes), | ||||
|                     None::<()>, | ||||
|                 )); | ||||
|             } | ||||
|             (data, None) | ||||
|         }; | ||||
|  | ||||
|         // Resolve image embedder and embed on a plain OS thread | ||||
|         let img_embedder = server | ||||
|             .get_image_embedder_for(&name) | ||||
|             .map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?; | ||||
|         let (tx, rx) = tokio::sync::oneshot::channel(); | ||||
|         let emb_arc = img_embedder.clone(); | ||||
|         let bytes_cl = bytes.clone(); | ||||
|         std::thread::spawn(move || { | ||||
|             let res = emb_arc.embed_image(&bytes_cl); | ||||
|             let _ = tx.send(res); | ||||
|         }); | ||||
|         let vector = match rx.await { | ||||
|             Ok(Ok(v)) => v, | ||||
|             Ok(Err(e)) => return Err(jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>)), | ||||
|             Err(recv_err) => { | ||||
|                 return Err(jsonrpsee::types::ErrorObjectOwned::owned( | ||||
|                     -32000, | ||||
|                     format!("embedding thread error: {}", recv_err), | ||||
|                     None::<()>, | ||||
|                 )) | ||||
|             } | ||||
|         }; | ||||
|  | ||||
|         // Store vector with media fields | ||||
|         server | ||||
|             .lance_store() | ||||
|             .map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))? | ||||
|             .store_vector_with_media( | ||||
|                 &name, | ||||
|                 &id, | ||||
|                 vector, | ||||
|                 meta.unwrap_or_default(), | ||||
|                 None, | ||||
|                 Some("image".to_string()), | ||||
|                 media_uri_opt, | ||||
|             ) | ||||
|             .await | ||||
|             .map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?; | ||||
|  | ||||
|         Ok(true) | ||||
|     } | ||||
|  | ||||
|     async fn lance_search_image( | ||||
|         &self, | ||||
|         db_id: u64, | ||||
|         name: String, | ||||
|         k: usize, | ||||
|         uri: Option<String>, | ||||
|         bytes_b64: Option<String>, | ||||
|         filter: Option<String>, | ||||
|         return_fields: Option<Vec<String>>, | ||||
|     ) -> RpcResult<serde_json::Value> { | ||||
|         let server = self.get_or_create_server(db_id).await?; | ||||
|         if db_id == 0 { | ||||
|             return Err(jsonrpsee::types::ErrorObjectOwned::owned(-32000, "Lance not allowed on DB 0", None::<()>)); | ||||
|         } | ||||
|         if !matches!(server.option.backend, crate::options::BackendType::Lance) { | ||||
|             return Err(jsonrpsee::types::ErrorObjectOwned::owned(-32000, "DB backend is not Lance", None::<()>)); | ||||
|         } | ||||
|         if !server.has_read_permission() { | ||||
|             return Err(jsonrpsee::types::ErrorObjectOwned::owned(-32000, "read permission denied", None::<()>)); | ||||
|         } | ||||
|  | ||||
|         // Validate exactly one of uri or bytes_b64 | ||||
|         let (use_uri, use_b64) = (uri.is_some(), bytes_b64.is_some()); | ||||
|         if (use_uri && use_b64) || (!use_uri && !use_b64) { | ||||
|             return Err(jsonrpsee::types::ErrorObjectOwned::owned( | ||||
|                 -32000, | ||||
|                 "Provide exactly one of 'uri' or 'bytes_b64'", | ||||
|                 None::<()>, | ||||
|             )); | ||||
|         } | ||||
|  | ||||
|         // Acquire image bytes for query (with caps) | ||||
|         let max_bytes: usize = std::env::var("HERODB_IMAGE_MAX_BYTES") | ||||
|             .ok() | ||||
|             .and_then(|s| s.parse::<u64>().ok()) | ||||
|             .unwrap_or(10 * 1024 * 1024) as usize; | ||||
|  | ||||
|         let bytes = if let Some(u) = uri { | ||||
|             server | ||||
|                 .fetch_image_bytes_from_uri(&u) | ||||
|                 .map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))? | ||||
|         } else { | ||||
|             let b64 = bytes_b64.unwrap_or_default(); | ||||
|             let data = general_purpose::STANDARD | ||||
|                 .decode(b64.as_bytes()) | ||||
|                 .map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, format!("base64 decode error: {}", e), None::<()>))?; | ||||
|             if data.len() > max_bytes { | ||||
|                 return Err(jsonrpsee::types::ErrorObjectOwned::owned( | ||||
|                     -32000, | ||||
|                     format!("Image exceeds max allowed bytes {}", max_bytes), | ||||
|                     None::<()>, | ||||
|                 )); | ||||
|             } | ||||
|             data | ||||
|         }; | ||||
|  | ||||
|         // Resolve image embedder and embed on OS thread | ||||
|         let img_embedder = server | ||||
|             .get_image_embedder_for(&name) | ||||
|             .map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?; | ||||
|         let (tx, rx) = tokio::sync::oneshot::channel(); | ||||
|         let emb_arc = img_embedder.clone(); | ||||
|         std::thread::spawn(move || { | ||||
|             let res = emb_arc.embed_image(&bytes); | ||||
|             let _ = tx.send(res); | ||||
|         }); | ||||
|         let qv = match rx.await { | ||||
|             Ok(Ok(v)) => v, | ||||
|             Ok(Err(e)) => return Err(jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>)), | ||||
|             Err(recv_err) => { | ||||
|                 return Err(jsonrpsee::types::ErrorObjectOwned::owned( | ||||
|                     -32000, | ||||
|                     format!("embedding thread error: {}", recv_err), | ||||
|                     None::<()>, | ||||
|                 )) | ||||
|             } | ||||
|         }; | ||||
|  | ||||
|         // KNN search and return results | ||||
|         let results = server | ||||
|             .lance_store() | ||||
|             .map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))? | ||||
|             .search_vectors(&name, qv, k, filter, return_fields) | ||||
|             .await | ||||
|             .map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?; | ||||
|  | ||||
|         let json_results: Vec<serde_json::Value> = results | ||||
|             .into_iter() | ||||
|             .map(|(id, score, meta)| { | ||||
|                 serde_json::json!({ | ||||
|                     "id": id, | ||||
|                     "score": score, | ||||
|                     "meta": meta, | ||||
|                 }) | ||||
|             }) | ||||
|             .collect(); | ||||
|  | ||||
|         Ok(serde_json::json!({ "results": json_results })) | ||||
|     } | ||||
| } | ||||
| @@ -15,8 +15,11 @@ use crate::storage_trait::StorageBackend; | ||||
| use crate::admin_meta; | ||||
|  | ||||
| // Embeddings: config and cache | ||||
| use crate::embedding::{EmbeddingConfig, create_embedder, Embedder}; | ||||
| use crate::embedding::{EmbeddingConfig, create_embedder, Embedder, create_image_embedder, ImageEmbedder}; | ||||
| use serde_json; | ||||
| use ureq::{Agent, AgentBuilder}; | ||||
| use std::time::Duration; | ||||
| use std::io::Read; | ||||
|  | ||||
| #[derive(Clone)] | ||||
| pub struct Server { | ||||
| @@ -33,9 +36,12 @@ pub struct Server { | ||||
|     // Per-DB Lance stores (vector DB), keyed by db_id | ||||
|     pub lance_stores: Arc<std::sync::RwLock<HashMap<u64, Arc<crate::lance_store::LanceStore>>>>, | ||||
|  | ||||
|     // Per-(db_id, dataset) embedder cache | ||||
|     // Per-(db_id, dataset) embedder cache (text) | ||||
|     pub embedders: Arc<std::sync::RwLock<HashMap<(u64, String), Arc<dyn Embedder>>>>, | ||||
|  | ||||
|     // Per-(db_id, dataset) image embedder cache (image) | ||||
|     pub image_embedders: Arc<std::sync::RwLock<HashMap<(u64, String), Arc<dyn ImageEmbedder>>>>, | ||||
|      | ||||
|     // BLPOP waiter registry: per (db_index, key) FIFO of waiters | ||||
|     pub list_waiters: Arc<Mutex<HashMap<u64, HashMap<String, Vec<Waiter>>>>>, | ||||
|     pub waiter_seq: Arc<AtomicU64>, | ||||
| @@ -66,6 +72,7 @@ impl Server { | ||||
|             search_indexes: Arc::new(std::sync::RwLock::new(HashMap::new())), | ||||
|             lance_stores: Arc::new(std::sync::RwLock::new(HashMap::new())), | ||||
|             embedders: Arc::new(std::sync::RwLock::new(HashMap::new())), | ||||
|             image_embedders: Arc::new(std::sync::RwLock::new(HashMap::new())), | ||||
|             list_waiters: Arc::new(Mutex::new(HashMap::new())), | ||||
|             waiter_seq: Arc::new(AtomicU64::new(1)), | ||||
|         } | ||||
| @@ -189,6 +196,10 @@ impl Server { | ||||
|             let mut map = self.embedders.write().unwrap(); | ||||
|             map.remove(&(self.selected_db, dataset.to_string())); | ||||
|         } | ||||
|         { | ||||
|             let mut map_img = self.image_embedders.write().unwrap(); | ||||
|             map_img.remove(&(self.selected_db, dataset.to_string())); | ||||
|         } | ||||
|         Ok(()) | ||||
|     } | ||||
|  | ||||
| @@ -233,6 +244,88 @@ impl Server { | ||||
|         Ok(emb) | ||||
|     } | ||||
|  | ||||
|     /// Resolve or build an IMAGE embedder for (db_id, dataset). Caches instance. | ||||
|     pub fn get_image_embedder_for(&self, dataset: &str) -> Result<Arc<dyn ImageEmbedder>, DBError> { | ||||
|         if self.selected_db == 0 { | ||||
|             return Err(DBError("Lance not available on admin DB 0".to_string())); | ||||
|         } | ||||
|         // Fast path | ||||
|         { | ||||
|             let map = self.image_embedders.read().unwrap(); | ||||
|             if let Some(e) = map.get(&(self.selected_db, dataset.to_string())) { | ||||
|                 return Ok(e.clone()); | ||||
|             } | ||||
|         } | ||||
|         // Load config and instantiate | ||||
|         let cfg = self.get_dataset_embedding_config(dataset)?; | ||||
|         let emb = create_image_embedder(&cfg)?; | ||||
|         { | ||||
|             let mut map = self.image_embedders.write().unwrap(); | ||||
|             map.insert((self.selected_db, dataset.to_string()), emb.clone()); | ||||
|         } | ||||
|         Ok(emb) | ||||
|     } | ||||
|  | ||||
|     /// Download image bytes from a URI with safety checks (size, timeout, content-type, optional host allowlist). | ||||
|     /// Env overrides: | ||||
|     /// - HERODB_IMAGE_MAX_BYTES (u64, default 10485760) | ||||
|     /// - HERODB_IMAGE_FETCH_TIMEOUT_SECS (u64, default 30) | ||||
|     /// - HERODB_IMAGE_ALLOWED_HOSTS (comma-separated, optional) | ||||
|     pub fn fetch_image_bytes_from_uri(&self, uri: &str) -> Result<Vec<u8>, DBError> { | ||||
|         // Basic scheme validation | ||||
|         if !(uri.starts_with("http://") || uri.starts_with("https://")) { | ||||
|             return Err(DBError("Only http(s) URIs are supported for image fetch".into())); | ||||
|         } | ||||
|         // Parse host (naive) for allowlist check | ||||
|         let host = { | ||||
|             let after_scheme = match uri.find("://") { | ||||
|                 Some(i) => &uri[i + 3..], | ||||
|                 None => uri, | ||||
|             }; | ||||
|             let end = after_scheme.find('/').unwrap_or(after_scheme.len()); | ||||
|             let host_port = &after_scheme[..end]; | ||||
|             host_port.split('@').last().unwrap_or(host_port).split(':').next().unwrap_or(host_port).to_string() | ||||
|         }; | ||||
|  | ||||
|         let max_bytes: u64 = std::env::var("HERODB_IMAGE_MAX_BYTES").ok().and_then(|s| s.parse::<u64>().ok()).unwrap_or(10 * 1024 * 1024); | ||||
|         let timeout_secs: u64 = std::env::var("HERODB_IMAGE_FETCH_TIMEOUT_SECS").ok().and_then(|s| s.parse::<u64>().ok()).unwrap_or(30); | ||||
|         let allowed_hosts_env = std::env::var("HERODB_IMAGE_ALLOWED_HOSTS").ok(); | ||||
|         if let Some(allow) = allowed_hosts_env { | ||||
|             if !allow.split(',').map(|s| s.trim()).filter(|s| !s.is_empty()).any(|h| h.eq_ignore_ascii_case(&host)) { | ||||
|                 return Err(DBError(format!("Host '{}' not allowed for image fetch (HERODB_IMAGE_ALLOWED_HOSTS)", host))); | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         let agent: Agent = AgentBuilder::new() | ||||
|             .timeout_read(Duration::from_secs(timeout_secs)) | ||||
|             .timeout_write(Duration::from_secs(timeout_secs)) | ||||
|             .build(); | ||||
|  | ||||
|         let resp = agent.get(uri).call().map_err(|e| DBError(format!("HTTP GET failed: {}", e)))?; | ||||
|         // Validate content-type | ||||
|         let ctype = resp.header("Content-Type").unwrap_or(""); | ||||
|         let ctype_main = ctype.split(';').next().unwrap_or("").trim().to_ascii_lowercase(); | ||||
|         if !ctype_main.starts_with("image/") { | ||||
|             return Err(DBError(format!("Remote content-type '{}' is not image/*", ctype))); | ||||
|         } | ||||
|  | ||||
|         // Read with cap | ||||
|         let mut reader = resp.into_reader(); | ||||
|         let mut buf: Vec<u8> = Vec::with_capacity(8192); | ||||
|         let mut tmp = [0u8; 8192]; | ||||
|         let mut total: u64 = 0; | ||||
|         loop { | ||||
|             let n = reader.read(&mut tmp).map_err(|e| DBError(format!("Read error: {}", e)))?; | ||||
|             if n == 0 { break; } | ||||
|             total += n as u64; | ||||
|             if total > max_bytes { | ||||
|                 return Err(DBError(format!("Image exceeds max allowed bytes {}", max_bytes))); | ||||
|             } | ||||
|             buf.extend_from_slice(&tmp[..n]); | ||||
|         } | ||||
|         Ok(buf) | ||||
|     } | ||||
|  | ||||
|     /// Check if current permissions allow read operations | ||||
|     pub fn has_read_permission(&self) -> bool { | ||||
|         // If an explicit permission is set for this connection, honor it. | ||||
|   | ||||
		Reference in New Issue
	
	Block a user