WIP3 implemeting lancedb
This commit is contained in:
200
src/embedding.rs
200
src/embedding.rs
@@ -21,6 +21,12 @@ use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::error::DBError;
|
||||
|
||||
// Networking for OpenAI/Azure
|
||||
use std::time::Duration;
|
||||
use reqwest::blocking::Client;
|
||||
use reqwest::header::{HeaderMap, HeaderName, HeaderValue, CONTENT_TYPE, AUTHORIZATION};
|
||||
use serde_json::json;
|
||||
|
||||
/// Provider identifiers. Extend as needed to mirror LanceDB-supported providers.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
@@ -122,17 +128,203 @@ impl Embedder for TestHashEmbedder {
|
||||
}
|
||||
}
|
||||
|
||||
//// OpenAI embedder (supports OpenAI and Azure OpenAI via REST)
|
||||
struct OpenAIEmbedder {
|
||||
model: String,
|
||||
dim: usize,
|
||||
client: Client,
|
||||
endpoint: String,
|
||||
auth_header_name: HeaderName,
|
||||
auth_header_value: HeaderValue,
|
||||
use_azure: bool,
|
||||
}
|
||||
|
||||
impl OpenAIEmbedder {
|
||||
fn new_from_config(cfg: &EmbeddingConfig) -> Result<Self, DBError> {
|
||||
// Whether to use Azure OpenAI
|
||||
let use_azure = cfg
|
||||
.get_param_string("use_azure")
|
||||
.map(|s| s.eq_ignore_ascii_case("true"))
|
||||
.unwrap_or(false);
|
||||
|
||||
// Resolve API key (OPENAI_API_KEY or AZURE_OPENAI_API_KEY by default)
|
||||
let api_key_env = cfg
|
||||
.get_param_string("api_key_env")
|
||||
.unwrap_or_else(|| {
|
||||
if use_azure {
|
||||
"AZURE_OPENAI_API_KEY".to_string()
|
||||
} else {
|
||||
"OPENAI_API_KEY".to_string()
|
||||
}
|
||||
});
|
||||
let api_key = std::env::var(&api_key_env)
|
||||
.map_err(|_| DBError(format!("Missing API key in env '{}'", api_key_env)))?;
|
||||
|
||||
// Resolve endpoint
|
||||
// - Standard OpenAI: https://api.openai.com/v1/embeddings (default) or params["base_url"]
|
||||
// - Azure OpenAI: {azure_endpoint}/openai/deployments/{deployment}/embeddings?api-version=...
|
||||
let endpoint = if use_azure {
|
||||
let base = cfg
|
||||
.get_param_string("azure_endpoint")
|
||||
.ok_or_else(|| DBError("Missing 'azure_endpoint' for Azure OpenAI".into()))?;
|
||||
let deployment = cfg
|
||||
.get_param_string("azure_deployment")
|
||||
.unwrap_or_else(|| cfg.model.clone());
|
||||
let api_version = cfg
|
||||
.get_param_string("azure_api_version")
|
||||
.unwrap_or_else(|| "2023-05-15".to_string());
|
||||
format!(
|
||||
"{}/openai/deployments/{}/embeddings?api-version={}",
|
||||
base.trim_end_matches('/'),
|
||||
deployment,
|
||||
api_version
|
||||
)
|
||||
} else {
|
||||
cfg.get_param_string("base_url")
|
||||
.unwrap_or_else(|| "https://api.openai.com/v1/embeddings".to_string())
|
||||
};
|
||||
|
||||
// Determine expected dimension:
|
||||
// - Prefer params["dim"] or params["dimensions"]
|
||||
// - Else default to 1536 (common for text-embedding-3-small; callers should override if needed)
|
||||
let dim = cfg
|
||||
.get_param_usize("dim")
|
||||
.or_else(|| cfg.get_param_usize("dimensions"))
|
||||
.unwrap_or(1536);
|
||||
|
||||
// Build default headers
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
|
||||
let (auth_name, auth_val) = if use_azure {
|
||||
let name = HeaderName::from_static("api-key");
|
||||
let val = HeaderValue::from_str(&api_key)
|
||||
.map_err(|_| DBError("Invalid API key header value".into()))?;
|
||||
(name, val)
|
||||
} else {
|
||||
let bearer = format!("Bearer {}", api_key);
|
||||
(AUTHORIZATION, HeaderValue::from_str(&bearer).map_err(|_| DBError("Invalid Authorization header".into()))?)
|
||||
};
|
||||
|
||||
let client = Client::builder()
|
||||
.timeout(Duration::from_secs(30))
|
||||
.default_headers(headers)
|
||||
.build()
|
||||
.map_err(|e| DBError(format!("Failed to build HTTP client: {}", e)))?;
|
||||
|
||||
Ok(Self {
|
||||
model: cfg.model.clone(),
|
||||
dim,
|
||||
client,
|
||||
endpoint,
|
||||
auth_header_name: auth_name,
|
||||
auth_header_value: auth_val,
|
||||
use_azure,
|
||||
})
|
||||
}
|
||||
|
||||
fn request_many(&self, inputs: &[String]) -> Result<Vec<Vec<f32>>, DBError> {
|
||||
// Compose request body:
|
||||
// - Standard OpenAI: { "model": ..., "input": [...], "dimensions": dim? }
|
||||
// - Azure: { "input": [...], "dimensions": dim? } (model from deployment)
|
||||
let mut body = if self.use_azure {
|
||||
json!({ "input": inputs })
|
||||
} else {
|
||||
json!({ "model": self.model, "input": inputs })
|
||||
};
|
||||
if self.dim > 0 {
|
||||
body.as_object_mut()
|
||||
.unwrap()
|
||||
.insert("dimensions".to_string(), json!(self.dim));
|
||||
}
|
||||
|
||||
let mut req = self.client.post(&self.endpoint);
|
||||
// Add auth header dynamically
|
||||
req = req.header(self.auth_header_name.clone(), self.auth_header_value.clone());
|
||||
|
||||
let resp = req
|
||||
.json(&body)
|
||||
.send()
|
||||
.map_err(|e| DBError(format!("HTTP request failed: {}", e)))?;
|
||||
if !resp.status().is_success() {
|
||||
let code = resp.status();
|
||||
let text = resp.text().unwrap_or_default();
|
||||
return Err(DBError(format!("Embeddings API error {}: {}", code, text)));
|
||||
}
|
||||
let val: serde_json::Value = resp
|
||||
.json()
|
||||
.map_err(|e| DBError(format!("Invalid JSON from embeddings API: {}", e)))?;
|
||||
|
||||
let data = val
|
||||
.get("data")
|
||||
.and_then(|d| d.as_array())
|
||||
.ok_or_else(|| DBError("Embeddings API response missing 'data' array".into()))?;
|
||||
|
||||
let mut out: Vec<Vec<f32>> = Vec::with_capacity(data.len());
|
||||
for item in data {
|
||||
let emb = item
|
||||
.get("embedding")
|
||||
.and_then(|e| e.as_array())
|
||||
.ok_or_else(|| DBError("Embeddings API item missing 'embedding'".into()))?;
|
||||
let mut v: Vec<f32> = Vec::with_capacity(emb.len());
|
||||
for n in emb {
|
||||
let f = n
|
||||
.as_f64()
|
||||
.ok_or_else(|| DBError("Embedding element is not a number".into()))?;
|
||||
v.push(f as f32);
|
||||
}
|
||||
if self.dim > 0 && v.len() != self.dim {
|
||||
return Err(DBError(format!(
|
||||
"Embedding dimension mismatch: expected {}, got {}. Configure 'dim' or 'dimensions' to match output.",
|
||||
self.dim, v.len()
|
||||
)));
|
||||
}
|
||||
out.push(v);
|
||||
}
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
|
||||
impl Embedder for OpenAIEmbedder {
|
||||
fn name(&self) -> String {
|
||||
if self.use_azure {
|
||||
format!("azure-openai:{}", self.model)
|
||||
} else {
|
||||
format!("openai:{}", self.model)
|
||||
}
|
||||
}
|
||||
|
||||
fn dim(&self) -> usize {
|
||||
self.dim
|
||||
}
|
||||
|
||||
fn embed(&self, text: &str) -> Result<Vec<f32>, DBError> {
|
||||
let v = self.request_many(&[text.to_string()])?;
|
||||
Ok(v.into_iter().next().unwrap_or_else(|| vec![0.0; self.dim]))
|
||||
}
|
||||
|
||||
fn embed_many(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, DBError> {
|
||||
if texts.is_empty() {
|
||||
return Ok(vec![]);
|
||||
}
|
||||
self.request_many(texts)
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an embedder instance from a config.
|
||||
/// - TestHash: uses params["dim"] or defaults to 64
|
||||
/// - Lance* providers: return an explicit error for now; implementers can wire these up
|
||||
/// - LanceOpenAI: uses OpenAI (or Azure OpenAI) embeddings REST API
|
||||
/// - Other Lance providers can be added similarly
|
||||
pub fn create_embedder(config: &EmbeddingConfig) -> Result<Arc<dyn Embedder>, DBError> {
|
||||
match &config.provider {
|
||||
EmbeddingProvider::TestHash => {
|
||||
let dim = config.get_param_usize("dim").unwrap_or(64);
|
||||
Ok(Arc::new(TestHashEmbedder::new(dim, config.model.clone())))
|
||||
}
|
||||
EmbeddingProvider::LanceFastEmbed => Err(DBError("LanceFastEmbed provider not yet implemented in Rust embedding layer; configure 'test-hash' or implement a Lance-backed provider".into())),
|
||||
EmbeddingProvider::LanceOpenAI => Err(DBError("LanceOpenAI provider not yet implemented in Rust embedding layer; configure 'test-hash' or implement a Lance-backed provider".into())),
|
||||
EmbeddingProvider::LanceOther(p) => Err(DBError(format!("Lance provider '{}' not implemented; configure 'test-hash' or implement a Lance-backed provider", p))),
|
||||
EmbeddingProvider::LanceOpenAI => {
|
||||
let inner = OpenAIEmbedder::new_from_config(config)?;
|
||||
Ok(Arc::new(inner))
|
||||
}
|
||||
EmbeddingProvider::LanceFastEmbed => Err(DBError("LanceFastEmbed provider not yet implemented in Rust embedding layer; configure 'test-hash' or use 'openai'".into())),
|
||||
EmbeddingProvider::LanceOther(p) => Err(DBError(format!("Lance provider '{}' not implemented; configure 'openai' or 'test-hash'", p))),
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user