WIP4 implementation lanceDB: removed blocking Tokio runtime usage during embeddings and isolated all embedding work off the async runtime

This commit is contained in:
Maxime Van Hees
2025-09-29 15:54:12 +02:00
parent 644946f1ca
commit 4aa49e0d5c
5 changed files with 132 additions and 211 deletions

View File

@@ -1363,9 +1363,20 @@ impl Cmd {
if !server.has_write_permission() {
return Ok(Protocol::err("ERR write permission denied"));
}
// Resolve embedder and embed text
// Resolve embedder and embed text on a plain OS thread to avoid tokio runtime panics from reqwest::blocking
let embedder = server.get_embedder_for(&name)?;
let vector = embedder.embed(&text)?;
let (tx, rx) = tokio::sync::oneshot::channel();
let emb_arc = embedder.clone();
let text_cl = text.clone();
std::thread::spawn(move || {
let res = emb_arc.embed(&text_cl);
let _ = tx.send(res);
});
let vector = match rx.await {
Ok(Ok(v)) => v,
Ok(Err(e)) => return Ok(Protocol::err(&e.0)),
Err(recv_err) => return Ok(Protocol::err(&format!("ERR embedding thread error: {}", recv_err))),
};
let meta_map: std::collections::HashMap<String, String> = meta.into_iter().collect();
match server.lance_store()?.store_vector(&name, &id, vector, meta_map, Some(text)).await {
Ok(()) => Ok(Protocol::SimpleString("OK".to_string())),
@@ -1373,9 +1384,20 @@ impl Cmd {
}
}
Cmd::LanceSearchText { name, text, k, filter, return_fields } => {
// Resolve embedder and embed query text
// Resolve embedder and embed query text on a plain OS thread
let embedder = server.get_embedder_for(&name)?;
let qv = embedder.embed(&text)?;
let (tx, rx) = tokio::sync::oneshot::channel();
let emb_arc = embedder.clone();
let text_cl = text.clone();
std::thread::spawn(move || {
let res = emb_arc.embed(&text_cl);
let _ = tx.send(res);
});
let qv = match rx.await {
Ok(Ok(v)) => v,
Ok(Err(e)) => return Ok(Protocol::err(&e.0)),
Err(recv_err) => return Ok(Protocol::err(&format!("ERR embedding thread error: {}", recv_err))),
};
match server.lance_store()?.search_vectors(&name, qv, k, filter, return_fields).await {
Ok(results) => {
// Encode as array of [id, score, [k1, v1, k2, v2, ...]]

View File

@@ -23,8 +23,7 @@ use crate::error::DBError;
// Networking for OpenAI/Azure
use std::time::Duration;
use reqwest::blocking::Client;
use reqwest::header::{HeaderMap, HeaderName, HeaderValue, CONTENT_TYPE, AUTHORIZATION};
use ureq::{Agent, AgentBuilder};
use serde_json::json;
/// Provider identifiers. Extend as needed to mirror LanceDB-supported providers.
@@ -132,10 +131,9 @@ impl Embedder for TestHashEmbedder {
struct OpenAIEmbedder {
model: String,
dim: usize,
client: Client,
agent: Agent,
endpoint: String,
auth_header_name: HeaderName,
auth_header_value: HeaderValue,
headers: Vec<(String, String)>,
use_azure: bool,
}
@@ -184,40 +182,33 @@ impl OpenAIEmbedder {
.unwrap_or_else(|| "https://api.openai.com/v1/embeddings".to_string())
};
// Determine expected dimension:
// - Prefer params["dim"] or params["dimensions"]
// - Else default to 1536 (common for text-embedding-3-small; callers should override if needed)
// Determine expected dimension (default 1536 for text-embedding-3-small; callers should override if needed)
let dim = cfg
.get_param_usize("dim")
.or_else(|| cfg.get_param_usize("dimensions"))
.unwrap_or(1536);
// Build default headers
let mut headers = HeaderMap::new();
headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
let (auth_name, auth_val) = if use_azure {
let name = HeaderName::from_static("api-key");
let val = HeaderValue::from_str(&api_key)
.map_err(|_| DBError("Invalid API key header value".into()))?;
(name, val)
} else {
let bearer = format!("Bearer {}", api_key);
(AUTHORIZATION, HeaderValue::from_str(&bearer).map_err(|_| DBError("Invalid Authorization header".into()))?)
};
// Build an HTTP agent with timeouts (blocking; no tokio runtime involved)
let agent = AgentBuilder::new()
.timeout_read(Duration::from_secs(30))
.timeout_write(Duration::from_secs(30))
.build();
let client = Client::builder()
.timeout(Duration::from_secs(30))
.default_headers(headers)
.build()
.map_err(|e| DBError(format!("Failed to build HTTP client: {}", e)))?;
// Headers
let mut headers: Vec<(String, String)> = Vec::new();
headers.push(("Content-Type".to_string(), "application/json".to_string()));
if use_azure {
headers.push(("api-key".to_string(), api_key));
} else {
headers.push(("Authorization".to_string(), format!("Bearer {}", api_key)));
}
Ok(Self {
model: cfg.model.clone(),
dim,
client,
agent,
endpoint,
auth_header_name: auth_name,
auth_header_value: auth_val,
headers,
use_azure,
})
}
@@ -237,21 +228,26 @@ impl OpenAIEmbedder {
.insert("dimensions".to_string(), json!(self.dim));
}
let mut req = self.client.post(&self.endpoint);
// Add auth header dynamically
req = req.header(self.auth_header_name.clone(), self.auth_header_value.clone());
let resp = req
.json(&body)
.send()
.map_err(|e| DBError(format!("HTTP request failed: {}", e)))?;
if !resp.status().is_success() {
let code = resp.status();
let text = resp.text().unwrap_or_default();
return Err(DBError(format!("Embeddings API error {}: {}", code, text)));
// Build request
let mut req = self.agent.post(&self.endpoint);
for (k, v) in &self.headers {
req = req.set(k, v);
}
let val: serde_json::Value = resp
.json()
// Send and handle errors
let resp = req.send_json(body);
let text = match resp {
Ok(r) => r
.into_string()
.map_err(|e| DBError(format!("Failed to read embeddings response: {}", e)))?,
Err(ureq::Error::Status(code, r)) => {
let body = r.into_string().unwrap_or_default();
return Err(DBError(format!("Embeddings API error {}: {}", code, body)));
}
Err(e) => return Err(DBError(format!("HTTP request failed: {}", e))),
};
let val: serde_json::Value = serde_json::from_str(&text)
.map_err(|e| DBError(format!("Invalid JSON from embeddings API: {}", e)))?;
let data = val

View File

@@ -1057,10 +1057,22 @@ impl RpcServer for RpcServerImpl {
if !server.has_write_permission() {
return Err(jsonrpsee::types::ErrorObjectOwned::owned(-32000, "write permission denied", None::<()>));
}
// Resolve embedder and run blocking embedding off the async runtime
// Resolve embedder and run embedding on a plain OS thread (avoid dropping any runtime in async context)
let embedder = server.get_embedder_for(&name)
.map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?;
let vector = embedder.embed(&text)
.map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?;
let (tx, rx) = tokio::sync::oneshot::channel();
let emb_arc = embedder.clone();
let text_cl = text.clone();
std::thread::spawn(move || {
let res = emb_arc.embed(&text_cl);
let _ = tx.send(res);
});
let vector = match rx.await {
Ok(Ok(v)) => v,
Ok(Err(e)) => return Err(jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>)),
Err(recv_err) => return Err(jsonrpsee::types::ErrorObjectOwned::owned(-32000, format!("embedding thread error: {}", recv_err), None::<()>)),
};
server.lance_store()
.map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?
.store_vector(&name, &id, vector, meta.unwrap_or_default(), Some(text)).await
@@ -1087,10 +1099,21 @@ impl RpcServer for RpcServerImpl {
if !server.has_read_permission() {
return Err(jsonrpsee::types::ErrorObjectOwned::owned(-32000, "read permission denied", None::<()>));
}
// Resolve embedder and run embedding on a plain OS thread (avoid dropping any runtime in async context)
let embedder = server.get_embedder_for(&name)
.map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?;
let qv = embedder.embed(&text)
.map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?;
let (tx, rx) = tokio::sync::oneshot::channel();
let emb_arc = embedder.clone();
let text_cl = text.clone();
std::thread::spawn(move || {
let res = emb_arc.embed(&text_cl);
let _ = tx.send(res);
});
let qv = match rx.await {
Ok(Ok(v)) => v,
Ok(Err(e)) => return Err(jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>)),
Err(recv_err) => return Err(jsonrpsee::types::ErrorObjectOwned::owned(-32000, format!("embedding thread error: {}", recv_err), None::<()>)),
};
let results = server.lance_store()
.map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?
.search_vectors(&name, qv, k, filter, return_fields).await