fixed a few bugs related to vector embedding + added additional end to end documentation to showcase local and external embedders step-by-step + added example mock embedder python script
This commit is contained in:
112
src/cmd.rs
112
src/cmd.rs
@@ -1,4 +1,4 @@
|
||||
use crate::{error::DBError, protocol::Protocol, server::Server, embedding::{EmbeddingConfig, EmbeddingProvider}};
|
||||
use crate::{error::DBError, protocol::Protocol, server::Server, embedding::EmbeddingConfig};
|
||||
use base64::{engine::general_purpose, Engine as _};
|
||||
use tokio::time::{timeout, Duration};
|
||||
use futures::future::select_all;
|
||||
@@ -170,9 +170,7 @@ pub enum Cmd {
|
||||
// Embedding configuration per dataset
|
||||
LanceEmbeddingConfigSet {
|
||||
name: String,
|
||||
provider: String,
|
||||
model: String,
|
||||
params: Vec<(String, String)>,
|
||||
config: EmbeddingConfig,
|
||||
},
|
||||
LanceEmbeddingConfigGet {
|
||||
name: String,
|
||||
@@ -1089,20 +1087,25 @@ impl Cmd {
|
||||
Cmd::LanceCreateIndex { name, index_type, params }
|
||||
}
|
||||
"lance.embedding" => {
|
||||
// LANCE.EMBEDDING CONFIG SET name PROVIDER p MODEL m [PARAM k v ...]
|
||||
// LANCE.EMBEDDING CONFIG SET name PROVIDER p MODEL m DIM d [ENDPOINT url] [HEADER k v]... [TIMEOUTMS t]
|
||||
// LANCE.EMBEDDING CONFIG GET name
|
||||
if cmd.len() < 3 || cmd[1].to_uppercase() != "CONFIG" {
|
||||
return Err(DBError("ERR LANCE.EMBEDDING requires CONFIG subcommand".to_string()));
|
||||
}
|
||||
if cmd.len() >= 4 && cmd[2].to_uppercase() == "SET" {
|
||||
if cmd.len() < 8 {
|
||||
return Err(DBError("ERR LANCE.EMBEDDING CONFIG SET requires: SET name PROVIDER p MODEL m [PARAM k v ...]".to_string()));
|
||||
return Err(DBError("ERR LANCE.EMBEDDING CONFIG SET requires: SET name PROVIDER p MODEL m DIM d [ENDPOINT url] [HEADER k v]... [TIMEOUTMS t]".to_string()));
|
||||
}
|
||||
let name = cmd[3].clone();
|
||||
let mut i = 4;
|
||||
|
||||
let mut provider: Option<String> = None;
|
||||
let mut model: Option<String> = None;
|
||||
let mut params: Vec<(String, String)> = Vec::new();
|
||||
let mut dim: Option<usize> = None;
|
||||
let mut endpoint: Option<String> = None;
|
||||
let mut headers: std::collections::HashMap<String, String> = std::collections::HashMap::new();
|
||||
let mut timeout_ms: Option<u64> = None;
|
||||
|
||||
while i < cmd.len() {
|
||||
match cmd[i].to_uppercase().as_str() {
|
||||
"PROVIDER" => {
|
||||
@@ -1119,22 +1122,62 @@ impl Cmd {
|
||||
model = Some(cmd[i + 1].clone());
|
||||
i += 2;
|
||||
}
|
||||
"PARAM" => {
|
||||
i += 1;
|
||||
while i + 1 < cmd.len() {
|
||||
params.push((cmd[i].clone(), cmd[i + 1].clone()));
|
||||
i += 2;
|
||||
"DIM" => {
|
||||
if i + 1 >= cmd.len() {
|
||||
return Err(DBError("ERR DIM requires a value".to_string()));
|
||||
}
|
||||
let d: usize = cmd[i + 1].parse().map_err(|_| DBError("ERR DIM must be an integer".to_string()))?;
|
||||
dim = Some(d);
|
||||
i += 2;
|
||||
}
|
||||
"ENDPOINT" => {
|
||||
if i + 1 >= cmd.len() {
|
||||
return Err(DBError("ERR ENDPOINT requires a value".to_string()));
|
||||
}
|
||||
endpoint = Some(cmd[i + 1].clone());
|
||||
i += 2;
|
||||
}
|
||||
"HEADER" => {
|
||||
if i + 2 >= cmd.len() {
|
||||
return Err(DBError("ERR HEADER requires key and value".to_string()));
|
||||
}
|
||||
headers.insert(cmd[i + 1].clone(), cmd[i + 2].clone());
|
||||
i += 3;
|
||||
}
|
||||
"TIMEOUTMS" => {
|
||||
if i + 1 >= cmd.len() {
|
||||
return Err(DBError("ERR TIMEOUTMS requires a value".to_string()));
|
||||
}
|
||||
let t: u64 = cmd[i + 1].parse().map_err(|_| DBError("ERR TIMEOUTMS must be an integer".to_string()))?;
|
||||
timeout_ms = Some(t);
|
||||
i += 2;
|
||||
}
|
||||
_ => {
|
||||
// Unknown token; break to avoid infinite loop
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
let provider = provider.ok_or_else(|| DBError("ERR missing PROVIDER".to_string()))?;
|
||||
|
||||
let provider_str = provider.ok_or_else(|| DBError("ERR missing PROVIDER".to_string()))?;
|
||||
let provider_enum = match provider_str.to_lowercase().as_str() {
|
||||
"openai" => crate::embedding::EmbeddingProvider::openai,
|
||||
"test" => crate::embedding::EmbeddingProvider::test,
|
||||
"image_test" | "imagetest" | "image-test" => crate::embedding::EmbeddingProvider::image_test,
|
||||
other => return Err(DBError(format!("ERR unsupported provider '{}'", other))),
|
||||
};
|
||||
let model = model.ok_or_else(|| DBError("ERR missing MODEL".to_string()))?;
|
||||
Cmd::LanceEmbeddingConfigSet { name, provider, model, params }
|
||||
let dim = dim.ok_or_else(|| DBError("ERR missing DIM".to_string()))?;
|
||||
|
||||
let config = EmbeddingConfig {
|
||||
provider: provider_enum,
|
||||
model,
|
||||
dim,
|
||||
endpoint,
|
||||
headers,
|
||||
timeout_ms,
|
||||
};
|
||||
|
||||
Cmd::LanceEmbeddingConfigSet { name, config }
|
||||
} else if cmd.len() == 4 && cmd[2].to_uppercase() == "GET" {
|
||||
let name = cmd[3].clone();
|
||||
Cmd::LanceEmbeddingConfigGet { name }
|
||||
@@ -1437,25 +1480,14 @@ impl Cmd {
|
||||
Err(e) => Ok(Protocol::err(&e.0)),
|
||||
}
|
||||
}
|
||||
Cmd::LanceEmbeddingConfigSet { name, provider, model, params } => {
|
||||
Cmd::LanceEmbeddingConfigSet { name, config } => {
|
||||
if !server.has_write_permission() {
|
||||
return Ok(Protocol::err("ERR write permission denied"));
|
||||
}
|
||||
// Map provider string to enum
|
||||
let p_lc = provider.to_lowercase();
|
||||
let prov = match p_lc.as_str() {
|
||||
"test-hash" | "testhash" => EmbeddingProvider::TestHash,
|
||||
"testimagehash" | "image-test-hash" | "imagetesthash" => EmbeddingProvider::ImageTestHash,
|
||||
"fastembed" | "lancefastembed" => EmbeddingProvider::LanceFastEmbed,
|
||||
"openai" | "lanceopenai" => EmbeddingProvider::LanceOpenAI,
|
||||
other => EmbeddingProvider::LanceOther(other.to_string()),
|
||||
};
|
||||
let cfg = EmbeddingConfig {
|
||||
provider: prov,
|
||||
model,
|
||||
params: params.into_iter().collect(),
|
||||
};
|
||||
match server.set_dataset_embedding_config(&name, &cfg) {
|
||||
if config.dim == 0 {
|
||||
return Ok(Protocol::err("ERR embedding DIM must be > 0"));
|
||||
}
|
||||
match server.set_dataset_embedding_config(&name, &config) {
|
||||
Ok(()) => Ok(Protocol::SimpleString("OK".to_string())),
|
||||
Err(e) => Ok(Protocol::err(&e.0)),
|
||||
}
|
||||
@@ -1466,16 +1498,20 @@ impl Cmd {
|
||||
let mut arr = Vec::new();
|
||||
arr.push(Protocol::BulkString("provider".to_string()));
|
||||
arr.push(Protocol::BulkString(match cfg.provider {
|
||||
EmbeddingProvider::TestHash => "test-hash".to_string(),
|
||||
EmbeddingProvider::ImageTestHash => "testimagehash".to_string(),
|
||||
EmbeddingProvider::LanceFastEmbed => "lancefastembed".to_string(),
|
||||
EmbeddingProvider::LanceOpenAI => "lanceopenai".to_string(),
|
||||
EmbeddingProvider::LanceOther(ref s) => s.clone(),
|
||||
crate::embedding::EmbeddingProvider::openai => "openai".to_string(),
|
||||
crate::embedding::EmbeddingProvider::test => "test".to_string(),
|
||||
crate::embedding::EmbeddingProvider::image_test => "image_test".to_string(),
|
||||
}));
|
||||
arr.push(Protocol::BulkString("model".to_string()));
|
||||
arr.push(Protocol::BulkString(cfg.model.clone()));
|
||||
arr.push(Protocol::BulkString("params".to_string()));
|
||||
arr.push(Protocol::BulkString(serde_json::to_string(&cfg.params).unwrap_or_else(|_| "{}".to_string())));
|
||||
arr.push(Protocol::BulkString("dim".to_string()));
|
||||
arr.push(Protocol::BulkString(cfg.dim.to_string()));
|
||||
arr.push(Protocol::BulkString("endpoint".to_string()));
|
||||
arr.push(Protocol::BulkString(cfg.endpoint.clone().unwrap_or_default()));
|
||||
arr.push(Protocol::BulkString("timeout_ms".to_string()));
|
||||
arr.push(Protocol::BulkString(cfg.timeout_ms.map(|v| v.to_string()).unwrap_or_default()));
|
||||
arr.push(Protocol::BulkString("headers".to_string()));
|
||||
arr.push(Protocol::BulkString(serde_json::to_string(&cfg.headers).unwrap_or_else(|_| "{}".to_string())));
|
||||
Ok(Protocol::Array(arr))
|
||||
}
|
||||
Err(e) => Ok(Protocol::err(&e.0)),
|
||||
|
||||
206
src/embedding.rs
206
src/embedding.rs
@@ -1,4 +1,4 @@
|
||||
// Embedding abstraction and minimal providers.
|
||||
// Embedding abstraction with a single external provider (OpenAI-compatible) and local test providers.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
@@ -7,42 +7,41 @@ use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::error::DBError;
|
||||
|
||||
// Networking for OpenAI/Azure
|
||||
// Networking for OpenAI-compatible endpoints
|
||||
use std::time::Duration;
|
||||
use ureq::{Agent, AgentBuilder};
|
||||
use serde_json::json;
|
||||
|
||||
/// Provider identifiers. Extend as needed to mirror LanceDB-supported providers.
|
||||
/// Provider identifiers (minimal set).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum EmbeddingProvider {
|
||||
// Deterministic, local-only embedder for CI and offline development (text).
|
||||
TestHash,
|
||||
// Deterministic, local-only embedder for CI and offline development (image).
|
||||
ImageTestHash,
|
||||
// Placeholders for LanceDB-supported providers; implementers can add concrete backends later.
|
||||
LanceFastEmbed,
|
||||
LanceOpenAI,
|
||||
LanceOther(String),
|
||||
/// External HTTP provider compatible with OpenAI's embeddings API.
|
||||
openai,
|
||||
/// Deterministic, local-only embedder for CI and offline development (text).
|
||||
test,
|
||||
/// Deterministic, local-only embedder for CI and offline development (image).
|
||||
image_test,
|
||||
}
|
||||
|
||||
/// Serializable embedding configuration.
|
||||
/// params: arbitrary key-value map for provider-specific knobs (e.g., "dim", "api_key_env", etc.)
|
||||
/// - provider: "openai" | "test" | "image_test"
|
||||
/// - model: provider/model id (e.g., "text-embedding-3-small"), may be ignored by local gateways
|
||||
/// - dim: required output dimension (used to create Lance datasets and validate outputs)
|
||||
/// - endpoint: optional HTTP endpoint (defaults to OpenAI API when provider == openai)
|
||||
/// - headers: optional HTTP headers (e.g., Authorization). If empty and OPENAI_API_KEY is present, Authorization will be inferred.
|
||||
/// - timeout_ms: optional HTTP timeout in milliseconds (for both read and write)
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct EmbeddingConfig {
|
||||
pub provider: EmbeddingProvider,
|
||||
pub model: String,
|
||||
pub dim: usize,
|
||||
#[serde(default)]
|
||||
pub params: HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl EmbeddingConfig {
|
||||
pub fn get_param_usize(&self, key: &str) -> Option<usize> {
|
||||
self.params.get(key).and_then(|v| v.parse::<usize>().ok())
|
||||
}
|
||||
pub fn get_param_string(&self, key: &str) -> Option<String> {
|
||||
self.params.get(key).cloned()
|
||||
}
|
||||
pub endpoint: Option<String>,
|
||||
#[serde(default)]
|
||||
pub headers: HashMap<String, String>,
|
||||
#[serde(default)]
|
||||
pub timeout_ms: Option<u64>,
|
||||
}
|
||||
|
||||
/// A provider-agnostic text embedding interface.
|
||||
@@ -59,6 +58,20 @@ pub trait Embedder: Send + Sync {
|
||||
}
|
||||
}
|
||||
|
||||
/// Image embedding interface (separate from text to keep modality-specific inputs).
|
||||
pub trait ImageEmbedder: Send + Sync {
|
||||
/// Human-readable provider/model name
|
||||
fn name(&self) -> String;
|
||||
/// Embedding dimension
|
||||
fn dim(&self) -> usize;
|
||||
/// Embed a single image (raw bytes)
|
||||
fn embed_image(&self, bytes: &[u8]) -> Result<Vec<f32>, DBError>;
|
||||
/// Embed many images; default maps embed_image() over inputs
|
||||
fn embed_many_images(&self, images: &[Vec<u8>]) -> Result<Vec<Vec<f32>>, DBError> {
|
||||
images.iter().map(|b| self.embed_image(b)).collect()
|
||||
}
|
||||
}
|
||||
|
||||
//// ----------------------------- TEXT: deterministic test embedder -----------------------------
|
||||
|
||||
/// Deterministic, no-deps, no-network embedder for CI and offline dev.
|
||||
@@ -88,7 +101,7 @@ impl TestHashEmbedder {
|
||||
|
||||
impl Embedder for TestHashEmbedder {
|
||||
fn name(&self) -> String {
|
||||
format!("test-hash:{}", self.model_name)
|
||||
format!("test:{}", self.model_name)
|
||||
}
|
||||
|
||||
fn dim(&self) -> usize {
|
||||
@@ -117,21 +130,7 @@ impl Embedder for TestHashEmbedder {
|
||||
}
|
||||
}
|
||||
|
||||
//// ----------------------------- IMAGE: trait + deterministic test embedder -----------------------------
|
||||
|
||||
/// Image embedding interface (separate from text to keep modality-specific inputs).
|
||||
pub trait ImageEmbedder: Send + Sync {
|
||||
/// Human-readable provider/model name
|
||||
fn name(&self) -> String;
|
||||
/// Embedding dimension
|
||||
fn dim(&self) -> usize;
|
||||
/// Embed a single image (raw bytes)
|
||||
fn embed_image(&self, bytes: &[u8]) -> Result<Vec<f32>, DBError>;
|
||||
/// Embed many images; default maps embed_image() over inputs
|
||||
fn embed_many_images(&self, images: &[Vec<u8>]) -> Result<Vec<Vec<f32>>, DBError> {
|
||||
images.iter().map(|b| self.embed_image(b)).collect()
|
||||
}
|
||||
}
|
||||
//// ----------------------------- IMAGE: deterministic test embedder -----------------------------
|
||||
|
||||
/// Deterministic image embedder that folds bytes into buckets, applies tanh-like nonlinearity,
|
||||
/// and L2-normalizes. Suitable for CI and offline development.
|
||||
@@ -159,7 +158,7 @@ impl TestImageHashEmbedder {
|
||||
|
||||
impl ImageEmbedder for TestImageHashEmbedder {
|
||||
fn name(&self) -> String {
|
||||
format!("test-image-hash:{}", self.model_name)
|
||||
format!("image_test:{}", self.model_name)
|
||||
}
|
||||
|
||||
fn dim(&self) -> usize {
|
||||
@@ -188,80 +187,45 @@ impl ImageEmbedder for TestImageHashEmbedder {
|
||||
}
|
||||
}
|
||||
|
||||
//// OpenAI embedder (supports OpenAI and Azure OpenAI via REST)
|
||||
//// ----------------------------- OpenAI-compatible HTTP embedder -----------------------------
|
||||
|
||||
struct OpenAIEmbedder {
|
||||
model: String,
|
||||
dim: usize,
|
||||
agent: Agent,
|
||||
endpoint: String,
|
||||
headers: Vec<(String, String)>,
|
||||
use_azure: bool,
|
||||
}
|
||||
|
||||
impl OpenAIEmbedder {
|
||||
fn new_from_config(cfg: &EmbeddingConfig) -> Result<Self, DBError> {
|
||||
// Whether to use Azure OpenAI
|
||||
let use_azure = cfg
|
||||
.get_param_string("use_azure")
|
||||
.map(|s| s.eq_ignore_ascii_case("true"))
|
||||
.unwrap_or(false);
|
||||
|
||||
// Resolve API key (OPENAI_API_KEY or AZURE_OPENAI_API_KEY by default)
|
||||
let api_key_env = cfg
|
||||
.get_param_string("api_key_env")
|
||||
.unwrap_or_else(|| {
|
||||
if use_azure {
|
||||
"AZURE_OPENAI_API_KEY".to_string()
|
||||
} else {
|
||||
"OPENAI_API_KEY".to_string()
|
||||
}
|
||||
});
|
||||
let api_key = std::env::var(&api_key_env)
|
||||
.map_err(|_| DBError(format!("Missing API key in env '{}'", api_key_env)))?;
|
||||
|
||||
// Resolve endpoint
|
||||
// - Standard OpenAI: https://api.openai.com/v1/embeddings (default) or params["base_url"]
|
||||
// - Azure OpenAI: {azure_endpoint}/openai/deployments/{deployment}/embeddings?api-version=...
|
||||
let endpoint = if use_azure {
|
||||
let base = cfg
|
||||
.get_param_string("azure_endpoint")
|
||||
.ok_or_else(|| DBError("Missing 'azure_endpoint' for Azure OpenAI".into()))?;
|
||||
let deployment = cfg
|
||||
.get_param_string("azure_deployment")
|
||||
.unwrap_or_else(|| cfg.model.clone());
|
||||
let api_version = cfg
|
||||
.get_param_string("azure_api_version")
|
||||
.unwrap_or_else(|| "2023-05-15".to_string());
|
||||
format!(
|
||||
"{}/openai/deployments/{}/embeddings?api-version={}",
|
||||
base.trim_end_matches('/'),
|
||||
deployment,
|
||||
api_version
|
||||
)
|
||||
} else {
|
||||
cfg.get_param_string("base_url")
|
||||
.unwrap_or_else(|| "https://api.openai.com/v1/embeddings".to_string())
|
||||
};
|
||||
let endpoint = cfg.endpoint.clone().unwrap_or_else(|| {
|
||||
"https://api.openai.com/v1/embeddings".to_string()
|
||||
});
|
||||
|
||||
// Determine expected dimension (default 1536 for text-embedding-3-small; callers should override if needed)
|
||||
let dim = cfg
|
||||
.get_param_usize("dim")
|
||||
.or_else(|| cfg.get_param_usize("dimensions"))
|
||||
.unwrap_or(1536);
|
||||
// Determine expected dimension (required by config)
|
||||
let dim = cfg.dim;
|
||||
|
||||
// Build an HTTP agent with timeouts (blocking; no tokio runtime involved)
|
||||
let to_ms = cfg.timeout_ms.unwrap_or(30_000);
|
||||
let agent = AgentBuilder::new()
|
||||
.timeout_read(Duration::from_secs(30))
|
||||
.timeout_write(Duration::from_secs(30))
|
||||
.timeout_read(Duration::from_millis(to_ms))
|
||||
.timeout_write(Duration::from_millis(to_ms))
|
||||
.build();
|
||||
|
||||
// Headers
|
||||
let mut headers: Vec<(String, String)> = Vec::new();
|
||||
headers.push(("Content-Type".to_string(), "application/json".to_string()));
|
||||
if use_azure {
|
||||
headers.push(("api-key".to_string(), api_key));
|
||||
} else {
|
||||
headers.push(("Authorization".to_string(), format!("Bearer {}", api_key)));
|
||||
// Headers: start from cfg.headers, and add Authorization from env if absent and available
|
||||
let mut headers: Vec<(String, String)> =
|
||||
cfg.headers.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
|
||||
|
||||
if !headers.iter().any(|(k, _)| k.eq_ignore_ascii_case("content-type")) {
|
||||
headers.push(("Content-Type".to_string(), "application/json".to_string()));
|
||||
}
|
||||
|
||||
if !headers.iter().any(|(k, _)| k.eq_ignore_ascii_case("authorization")) {
|
||||
if let Ok(key) = std::env::var("OPENAI_API_KEY") {
|
||||
headers.push(("Authorization".to_string(), format!("Bearer {}", key)));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
@@ -270,19 +234,12 @@ impl OpenAIEmbedder {
|
||||
agent,
|
||||
endpoint,
|
||||
headers,
|
||||
use_azure,
|
||||
})
|
||||
}
|
||||
|
||||
fn request_many(&self, inputs: &[String]) -> Result<Vec<Vec<f32>>, DBError> {
|
||||
// Compose request body:
|
||||
// - Standard OpenAI: { "model": ..., "input": [...], "dimensions": dim? }
|
||||
// - Azure: { "input": [...], "dimensions": dim? } (model from deployment)
|
||||
let mut body = if self.use_azure {
|
||||
json!({ "input": inputs })
|
||||
} else {
|
||||
json!({ "model": self.model, "input": inputs })
|
||||
};
|
||||
// Compose request body (OpenAI-compatible)
|
||||
let mut body = json!({ "model": self.model, "input": inputs });
|
||||
if self.dim > 0 {
|
||||
body.as_object_mut()
|
||||
.unwrap()
|
||||
@@ -331,7 +288,7 @@ impl OpenAIEmbedder {
|
||||
}
|
||||
if self.dim > 0 && v.len() != self.dim {
|
||||
return Err(DBError(format!(
|
||||
"Embedding dimension mismatch: expected {}, got {}. Configure 'dim' or 'dimensions' to match output.",
|
||||
"Embedding dimension mismatch: expected {}, got {}. Configure 'dim' to match output.",
|
||||
self.dim, v.len()
|
||||
)));
|
||||
}
|
||||
@@ -343,11 +300,7 @@ impl OpenAIEmbedder {
|
||||
|
||||
impl Embedder for OpenAIEmbedder {
|
||||
fn name(&self) -> String {
|
||||
if self.use_azure {
|
||||
format!("azure-openai:{}", self.model)
|
||||
} else {
|
||||
format!("openai:{}", self.model)
|
||||
}
|
||||
format!("openai:{}", self.model)
|
||||
}
|
||||
|
||||
fn dim(&self) -> usize {
|
||||
@@ -368,38 +321,33 @@ impl Embedder for OpenAIEmbedder {
|
||||
}
|
||||
|
||||
/// Create an embedder instance from a config.
|
||||
/// - TestHash: uses params["dim"] or defaults to 64
|
||||
/// - LanceOpenAI: uses OpenAI (or Azure OpenAI) embeddings REST API
|
||||
/// - Other Lance providers can be added similarly
|
||||
/// - openai: uses OpenAI-compatible embeddings REST API (endpoint override supported)
|
||||
/// - test: deterministic local text embedder (no network)
|
||||
/// - image_test: not valid for text (use create_image_embedder)
|
||||
pub fn create_embedder(config: &EmbeddingConfig) -> Result<Arc<dyn Embedder>, DBError> {
|
||||
match &config.provider {
|
||||
EmbeddingProvider::TestHash => {
|
||||
let dim = config.get_param_usize("dim").unwrap_or(64);
|
||||
Ok(Arc::new(TestHashEmbedder::new(dim, config.model.clone())))
|
||||
}
|
||||
EmbeddingProvider::LanceOpenAI => {
|
||||
EmbeddingProvider::openai => {
|
||||
let inner = OpenAIEmbedder::new_from_config(config)?;
|
||||
Ok(Arc::new(inner))
|
||||
}
|
||||
EmbeddingProvider::ImageTestHash => {
|
||||
EmbeddingProvider::test => {
|
||||
Ok(Arc::new(TestHashEmbedder::new(config.dim, config.model.clone())))
|
||||
}
|
||||
EmbeddingProvider::image_test => {
|
||||
Err(DBError("Use create_image_embedder() for image providers".into()))
|
||||
}
|
||||
EmbeddingProvider::LanceFastEmbed => Err(DBError("LanceFastEmbed provider not yet implemented in Rust embedding layer; configure 'test-hash' or use 'openai'".into())),
|
||||
EmbeddingProvider::LanceOther(p) => Err(DBError(format!("Lance provider '{}' not implemented; configure 'openai' or 'test-hash'", p))),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an image embedder instance from a config.
|
||||
/// - image_test: deterministic local image embedder
|
||||
pub fn create_image_embedder(config: &EmbeddingConfig) -> Result<Arc<dyn ImageEmbedder>, DBError> {
|
||||
match &config.provider {
|
||||
EmbeddingProvider::ImageTestHash => {
|
||||
let dim = config.get_param_usize("dim").unwrap_or(512);
|
||||
Ok(Arc::new(TestImageHashEmbedder::new(dim, config.model.clone())))
|
||||
EmbeddingProvider::image_test => {
|
||||
Ok(Arc::new(TestImageHashEmbedder::new(config.dim, config.model.clone())))
|
||||
}
|
||||
EmbeddingProvider::TestHash | EmbeddingProvider::LanceOpenAI => {
|
||||
Err(DBError("Configured text provider; dataset expects image provider (e.g., 'testimagehash')".into()))
|
||||
EmbeddingProvider::test | EmbeddingProvider::openai => {
|
||||
Err(DBError("Configured text provider; dataset expects image provider (e.g., 'image_test')".into()))
|
||||
}
|
||||
EmbeddingProvider::LanceFastEmbed => Err(DBError("Image provider 'lancefastembed' not yet implemented".into())),
|
||||
EmbeddingProvider::LanceOther(p) => Err(DBError(format!("Image provider '{}' not implemented; use 'testimagehash' for now", p))),
|
||||
}
|
||||
}
|
||||
46
src/rpc.rs
46
src/rpc.rs
@@ -9,7 +9,7 @@ use sha2::{Digest, Sha256};
|
||||
use crate::server::Server;
|
||||
use crate::options::DBOption;
|
||||
use crate::admin_meta;
|
||||
use crate::embedding::{EmbeddingConfig, EmbeddingProvider};
|
||||
use crate::embedding::EmbeddingConfig;
|
||||
use base64::{engine::general_purpose, Engine as _};
|
||||
|
||||
/// Database backend types
|
||||
@@ -248,9 +248,7 @@ pub trait Rpc {
|
||||
&self,
|
||||
db_id: u64,
|
||||
name: String,
|
||||
provider: String,
|
||||
model: String,
|
||||
params: Option<HashMap<String, String>>,
|
||||
config: EmbeddingConfig,
|
||||
) -> RpcResult<bool>;
|
||||
|
||||
/// Get per-dataset embedding configuration
|
||||
@@ -1008,9 +1006,7 @@ impl RpcServer for RpcServerImpl {
|
||||
&self,
|
||||
db_id: u64,
|
||||
name: String,
|
||||
provider: String,
|
||||
model: String,
|
||||
params: Option<HashMap<String, String>>,
|
||||
config: EmbeddingConfig,
|
||||
) -> RpcResult<bool> {
|
||||
let server = self.get_or_create_server(db_id).await?;
|
||||
if db_id == 0 {
|
||||
@@ -1022,19 +1018,17 @@ impl RpcServer for RpcServerImpl {
|
||||
if !server.has_write_permission() {
|
||||
return Err(jsonrpsee::types::ErrorObjectOwned::owned(-32000, "write permission denied", None::<()>));
|
||||
}
|
||||
let prov = match provider.to_lowercase().as_str() {
|
||||
"test-hash" | "testhash" => EmbeddingProvider::TestHash,
|
||||
"testimagehash" | "image-test-hash" | "imagetesthash" => EmbeddingProvider::ImageTestHash,
|
||||
"fastembed" | "lancefastembed" => EmbeddingProvider::LanceFastEmbed,
|
||||
"openai" | "lanceopenai" => EmbeddingProvider::LanceOpenAI,
|
||||
other => EmbeddingProvider::LanceOther(other.to_string()),
|
||||
};
|
||||
let cfg = EmbeddingConfig {
|
||||
provider: prov,
|
||||
model,
|
||||
params: params.unwrap_or_default(),
|
||||
};
|
||||
server.set_dataset_embedding_config(&name, &cfg)
|
||||
// Validate provider and dimension (only a minimal set is allowed for now)
|
||||
match config.provider {
|
||||
crate::embedding::EmbeddingProvider::openai
|
||||
| crate::embedding::EmbeddingProvider::test
|
||||
| crate::embedding::EmbeddingProvider::image_test => {}
|
||||
}
|
||||
if config.dim == 0 {
|
||||
return Err(jsonrpsee::types::ErrorObjectOwned::owned(-32000, "Invalid embedding config: dim must be > 0", None::<()>));
|
||||
}
|
||||
|
||||
server.set_dataset_embedding_config(&name, &config)
|
||||
.map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?;
|
||||
Ok(true)
|
||||
}
|
||||
@@ -1056,17 +1050,7 @@ impl RpcServer for RpcServerImpl {
|
||||
}
|
||||
let cfg = server.get_dataset_embedding_config(&name)
|
||||
.map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?;
|
||||
Ok(serde_json::json!({
|
||||
"provider": match cfg.provider {
|
||||
EmbeddingProvider::TestHash => "test-hash",
|
||||
EmbeddingProvider::ImageTestHash => "testimagehash",
|
||||
EmbeddingProvider::LanceFastEmbed => "lancefastembed",
|
||||
EmbeddingProvider::LanceOpenAI => "lanceopenai",
|
||||
EmbeddingProvider::LanceOther(ref s) => s,
|
||||
},
|
||||
"model": cfg.model,
|
||||
"params": cfg.params
|
||||
}))
|
||||
Ok(serde_json::to_value(&cfg).unwrap_or(serde_json::json!({})))
|
||||
}
|
||||
|
||||
async fn lance_store_text(
|
||||
|
||||
Reference in New Issue
Block a user