WIP3: implementing LanceDB

This commit is contained in:
Maxime Van Hees
2025-09-29 14:55:41 +02:00
parent cf66f4c304
commit 644946f1ca
4 changed files with 799 additions and 4 deletions

View File

@@ -21,6 +21,12 @@ use serde::{Deserialize, Serialize};
use crate::error::DBError;
// Networking for OpenAI/Azure
use std::time::Duration;
use reqwest::blocking::Client;
use reqwest::header::{HeaderMap, HeaderName, HeaderValue, CONTENT_TYPE, AUTHORIZATION};
use serde_json::json;
/// Provider identifiers. Extend as needed to mirror LanceDB-supported providers.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
@@ -122,17 +128,203 @@ impl Embedder for TestHashEmbedder {
}
}
/// OpenAI embedder (supports OpenAI and Azure OpenAI via REST).
///
/// Holds a pre-configured blocking HTTP client plus the resolved endpoint and
/// auth header, so each embedding request only has to compose the JSON body.
struct OpenAIEmbedder {
    // Model name sent in the request body; for Azure the model is instead
    // selected by the deployment encoded in `endpoint`.
    model: String,
    // Expected embedding dimensionality; also sent as the "dimensions" field
    // and used to validate the size of each returned vector.
    dim: usize,
    // Blocking reqwest client (30s timeout, JSON content-type default header).
    client: Client,
    // Fully-resolved embeddings URL: OpenAI base URL, or the Azure
    // deployment-specific URL including the api-version query parameter.
    endpoint: String,
    // Auth header applied per request: "api-key: <key>" for Azure,
    // "Authorization: Bearer <key>" for standard OpenAI.
    auth_header_name: HeaderName,
    auth_header_value: HeaderValue,
    // True when targeting Azure OpenAI (changes endpoint shape, auth header,
    // and request body composition).
    use_azure: bool,
}
impl OpenAIEmbedder {
    /// Build an embedder from an [`EmbeddingConfig`].
    ///
    /// Recognized params (all optional unless noted):
    /// - "use_azure": "true" (case-insensitive) switches to Azure OpenAI conventions
    /// - "api_key_env": name of the env var holding the API key
    ///   (defaults to AZURE_OPENAI_API_KEY / OPENAI_API_KEY)
    /// - "azure_endpoint": required when use_azure; "azure_deployment" defaults
    ///   to the model name; "azure_api_version" defaults to "2023-05-15"
    /// - "base_url": overrides the standard OpenAI embeddings URL
    /// - "dim" / "dimensions": expected embedding size (defaults to 1536)
    ///
    /// Errors if the API key env var is unset, the key is not a valid header
    /// value, the Azure endpoint is missing, or the HTTP client fails to build.
    fn new_from_config(cfg: &EmbeddingConfig) -> Result<Self, DBError> {
        // Whether to use Azure OpenAI
        let use_azure = cfg
            .get_param_string("use_azure")
            .map(|s| s.eq_ignore_ascii_case("true"))
            .unwrap_or(false);
        // Resolve API key (OPENAI_API_KEY or AZURE_OPENAI_API_KEY by default)
        let api_key_env = cfg
            .get_param_string("api_key_env")
            .unwrap_or_else(|| {
                if use_azure {
                    "AZURE_OPENAI_API_KEY".to_string()
                } else {
                    "OPENAI_API_KEY".to_string()
                }
            });
        let api_key = std::env::var(&api_key_env)
            .map_err(|_| DBError(format!("Missing API key in env '{}'", api_key_env)))?;
        // Resolve endpoint
        // - Standard OpenAI: https://api.openai.com/v1/embeddings (default) or params["base_url"]
        // - Azure OpenAI: {azure_endpoint}/openai/deployments/{deployment}/embeddings?api-version=...
        let endpoint = if use_azure {
            let base = cfg
                .get_param_string("azure_endpoint")
                .ok_or_else(|| DBError("Missing 'azure_endpoint' for Azure OpenAI".into()))?;
            // Azure selects the model via the deployment name, not a body field.
            let deployment = cfg
                .get_param_string("azure_deployment")
                .unwrap_or_else(|| cfg.model.clone());
            let api_version = cfg
                .get_param_string("azure_api_version")
                .unwrap_or_else(|| "2023-05-15".to_string());
            // trim_end_matches avoids a double slash when the endpoint
            // is configured with a trailing '/'.
            format!(
                "{}/openai/deployments/{}/embeddings?api-version={}",
                base.trim_end_matches('/'),
                deployment,
                api_version
            )
        } else {
            cfg.get_param_string("base_url")
                .unwrap_or_else(|| "https://api.openai.com/v1/embeddings".to_string())
        };
        // Determine expected dimension:
        // - Prefer params["dim"] or params["dimensions"]
        // - Else default to 1536 (common for text-embedding-3-small; callers should override if needed)
        let dim = cfg
            .get_param_usize("dim")
            .or_else(|| cfg.get_param_usize("dimensions"))
            .unwrap_or(1536);
        // Build default headers (only content-type; the auth header is kept
        // separate and attached per request in request_many).
        let mut headers = HeaderMap::new();
        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
        let (auth_name, auth_val) = if use_azure {
            // Azure uses a bare "api-key" header instead of a Bearer token.
            let name = HeaderName::from_static("api-key");
            let val = HeaderValue::from_str(&api_key)
                .map_err(|_| DBError("Invalid API key header value".into()))?;
            (name, val)
        } else {
            let bearer = format!("Bearer {}", api_key);
            (AUTHORIZATION, HeaderValue::from_str(&bearer).map_err(|_| DBError("Invalid Authorization header".into()))?)
        };
        let client = Client::builder()
            .timeout(Duration::from_secs(30))
            .default_headers(headers)
            .build()
            .map_err(|e| DBError(format!("Failed to build HTTP client: {}", e)))?;
        Ok(Self {
            model: cfg.model.clone(),
            dim,
            client,
            endpoint,
            auth_header_name: auth_name,
            auth_header_value: auth_val,
            use_azure,
        })
    }

    /// Send one embeddings request for `inputs` and return the vectors in
    /// input order. Validates that every returned vector has length `self.dim`.
    fn request_many(&self, inputs: &[String]) -> Result<Vec<Vec<f32>>, DBError> {
        // Compose request body:
        // - Standard OpenAI: { "model": ..., "input": [...], "dimensions": dim? }
        // - Azure: { "input": [...], "dimensions": dim? } (model from deployment)
        let mut body = if self.use_azure {
            json!({ "input": inputs })
        } else {
            json!({ "model": self.model, "input": inputs })
        };
        // NOTE(review): "dimensions" is sent whenever dim > 0, which is always
        // the case given the 1536 default above. The parameter is only accepted
        // by newer embedding models (e.g. text-embedding-3-*); older models may
        // reject the request — confirm against the configured model set.
        if self.dim > 0 {
            body.as_object_mut()
                .unwrap()
                .insert("dimensions".to_string(), json!(self.dim));
        }
        let mut req = self.client.post(&self.endpoint);
        // Add auth header dynamically
        req = req.header(self.auth_header_name.clone(), self.auth_header_value.clone());
        let resp = req
            .json(&body)
            .send()
            .map_err(|e| DBError(format!("HTTP request failed: {}", e)))?;
        // Surface the API's error payload verbatim for diagnosability.
        if !resp.status().is_success() {
            let code = resp.status();
            let text = resp.text().unwrap_or_default();
            return Err(DBError(format!("Embeddings API error {}: {}", code, text)));
        }
        let val: serde_json::Value = resp
            .json()
            .map_err(|e| DBError(format!("Invalid JSON from embeddings API: {}", e)))?;
        // Expected response shape: { "data": [ { "embedding": [f64, ...] }, ... ] }
        let data = val
            .get("data")
            .and_then(|d| d.as_array())
            .ok_or_else(|| DBError("Embeddings API response missing 'data' array".into()))?;
        let mut out: Vec<Vec<f32>> = Vec::with_capacity(data.len());
        for item in data {
            let emb = item
                .get("embedding")
                .and_then(|e| e.as_array())
                .ok_or_else(|| DBError("Embeddings API item missing 'embedding'".into()))?;
            let mut v: Vec<f32> = Vec::with_capacity(emb.len());
            for n in emb {
                // JSON numbers arrive as f64; narrow to f32 for storage.
                let f = n
                    .as_f64()
                    .ok_or_else(|| DBError("Embedding element is not a number".into()))?;
                v.push(f as f32);
            }
            // Fail fast on a dimension mismatch rather than storing vectors
            // that would be incompatible with the configured index.
            if self.dim > 0 && v.len() != self.dim {
                return Err(DBError(format!(
                    "Embedding dimension mismatch: expected {}, got {}. Configure 'dim' or 'dimensions' to match output.",
                    self.dim, v.len()
                )));
            }
            out.push(v);
        }
        Ok(out)
    }
}
impl Embedder for OpenAIEmbedder {
fn name(&self) -> String {
if self.use_azure {
format!("azure-openai:{}", self.model)
} else {
format!("openai:{}", self.model)
}
}
fn dim(&self) -> usize {
self.dim
}
fn embed(&self, text: &str) -> Result<Vec<f32>, DBError> {
let v = self.request_many(&[text.to_string()])?;
Ok(v.into_iter().next().unwrap_or_else(|| vec![0.0; self.dim]))
}
fn embed_many(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, DBError> {
if texts.is_empty() {
return Ok(vec![]);
}
self.request_many(texts)
}
}
/// Create an embedder instance from a config.
/// - TestHash: deterministic hash-based embedder; uses params["dim"] or defaults to 64
/// - LanceOpenAI: OpenAI (or Azure OpenAI) embeddings REST API
/// - LanceFastEmbed / LanceOther: not yet implemented; returns an explicit error
pub fn create_embedder(config: &EmbeddingConfig) -> Result<Arc<dyn Embedder>, DBError> {
    // NOTE: an earlier revision had two arms per Lance* variant; the stale
    // error-returning LanceOpenAI arm came first, making the real constructor
    // unreachable so LanceOpenAI always failed. Each variant now has exactly
    // one arm.
    match &config.provider {
        EmbeddingProvider::TestHash => {
            let dim = config.get_param_usize("dim").unwrap_or(64);
            Ok(Arc::new(TestHashEmbedder::new(dim, config.model.clone())))
        }
        EmbeddingProvider::LanceOpenAI => {
            let inner = OpenAIEmbedder::new_from_config(config)?;
            Ok(Arc::new(inner))
        }
        EmbeddingProvider::LanceFastEmbed => Err(DBError(
            "LanceFastEmbed provider not yet implemented in Rust embedding layer; configure 'test-hash' or use 'openai'".into(),
        )),
        EmbeddingProvider::LanceOther(p) => Err(DBError(format!(
            "Lance provider '{}' not implemented; configure 'openai' or 'test-hash'",
            p
        ))),
    }
}