WIP 4, LanceDB implementation: removed blocking Tokio runtime usage during embeddings and moved all embedding work off the async runtime
This commit is contained in:
30
src/cmd.rs
30
src/cmd.rs
@@ -1363,9 +1363,20 @@ impl Cmd {
|
||||
if !server.has_write_permission() {
|
||||
return Ok(Protocol::err("ERR write permission denied"));
|
||||
}
|
||||
// Resolve embedder and embed text
|
||||
// Resolve embedder and embed text on a plain OS thread to avoid tokio runtime panics from reqwest::blocking
|
||||
let embedder = server.get_embedder_for(&name)?;
|
||||
let vector = embedder.embed(&text)?;
|
||||
let (tx, rx) = tokio::sync::oneshot::channel();
|
||||
let emb_arc = embedder.clone();
|
||||
let text_cl = text.clone();
|
||||
std::thread::spawn(move || {
|
||||
let res = emb_arc.embed(&text_cl);
|
||||
let _ = tx.send(res);
|
||||
});
|
||||
let vector = match rx.await {
|
||||
Ok(Ok(v)) => v,
|
||||
Ok(Err(e)) => return Ok(Protocol::err(&e.0)),
|
||||
Err(recv_err) => return Ok(Protocol::err(&format!("ERR embedding thread error: {}", recv_err))),
|
||||
};
|
||||
let meta_map: std::collections::HashMap<String, String> = meta.into_iter().collect();
|
||||
match server.lance_store()?.store_vector(&name, &id, vector, meta_map, Some(text)).await {
|
||||
Ok(()) => Ok(Protocol::SimpleString("OK".to_string())),
|
||||
@@ -1373,9 +1384,20 @@ impl Cmd {
|
||||
}
|
||||
}
|
||||
Cmd::LanceSearchText { name, text, k, filter, return_fields } => {
|
||||
// Resolve embedder and embed query text
|
||||
// Resolve embedder and embed query text on a plain OS thread
|
||||
let embedder = server.get_embedder_for(&name)?;
|
||||
let qv = embedder.embed(&text)?;
|
||||
let (tx, rx) = tokio::sync::oneshot::channel();
|
||||
let emb_arc = embedder.clone();
|
||||
let text_cl = text.clone();
|
||||
std::thread::spawn(move || {
|
||||
let res = emb_arc.embed(&text_cl);
|
||||
let _ = tx.send(res);
|
||||
});
|
||||
let qv = match rx.await {
|
||||
Ok(Ok(v)) => v,
|
||||
Ok(Err(e)) => return Ok(Protocol::err(&e.0)),
|
||||
Err(recv_err) => return Ok(Protocol::err(&format!("ERR embedding thread error: {}", recv_err))),
|
||||
};
|
||||
match server.lance_store()?.search_vectors(&name, qv, k, filter, return_fields).await {
|
||||
Ok(results) => {
|
||||
// Encode as array of [id, score, [k1, v1, k2, v2, ...]]
|
||||
|
||||
@@ -23,8 +23,7 @@ use crate::error::DBError;
|
||||
|
||||
// Networking for OpenAI/Azure
|
||||
use std::time::Duration;
|
||||
use reqwest::blocking::Client;
|
||||
use reqwest::header::{HeaderMap, HeaderName, HeaderValue, CONTENT_TYPE, AUTHORIZATION};
|
||||
use ureq::{Agent, AgentBuilder};
|
||||
use serde_json::json;
|
||||
|
||||
/// Provider identifiers. Extend as needed to mirror LanceDB-supported providers.
|
||||
@@ -132,10 +131,9 @@ impl Embedder for TestHashEmbedder {
|
||||
struct OpenAIEmbedder {
|
||||
model: String,
|
||||
dim: usize,
|
||||
client: Client,
|
||||
agent: Agent,
|
||||
endpoint: String,
|
||||
auth_header_name: HeaderName,
|
||||
auth_header_value: HeaderValue,
|
||||
headers: Vec<(String, String)>,
|
||||
use_azure: bool,
|
||||
}
|
||||
|
||||
@@ -184,40 +182,33 @@ impl OpenAIEmbedder {
|
||||
.unwrap_or_else(|| "https://api.openai.com/v1/embeddings".to_string())
|
||||
};
|
||||
|
||||
// Determine expected dimension:
|
||||
// - Prefer params["dim"] or params["dimensions"]
|
||||
// - Else default to 1536 (common for text-embedding-3-small; callers should override if needed)
|
||||
// Determine expected dimension (default 1536 for text-embedding-3-small; callers should override if needed)
|
||||
let dim = cfg
|
||||
.get_param_usize("dim")
|
||||
.or_else(|| cfg.get_param_usize("dimensions"))
|
||||
.unwrap_or(1536);
|
||||
|
||||
// Build default headers
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
|
||||
let (auth_name, auth_val) = if use_azure {
|
||||
let name = HeaderName::from_static("api-key");
|
||||
let val = HeaderValue::from_str(&api_key)
|
||||
.map_err(|_| DBError("Invalid API key header value".into()))?;
|
||||
(name, val)
|
||||
} else {
|
||||
let bearer = format!("Bearer {}", api_key);
|
||||
(AUTHORIZATION, HeaderValue::from_str(&bearer).map_err(|_| DBError("Invalid Authorization header".into()))?)
|
||||
};
|
||||
// Build an HTTP agent with timeouts (blocking; no tokio runtime involved)
|
||||
let agent = AgentBuilder::new()
|
||||
.timeout_read(Duration::from_secs(30))
|
||||
.timeout_write(Duration::from_secs(30))
|
||||
.build();
|
||||
|
||||
let client = Client::builder()
|
||||
.timeout(Duration::from_secs(30))
|
||||
.default_headers(headers)
|
||||
.build()
|
||||
.map_err(|e| DBError(format!("Failed to build HTTP client: {}", e)))?;
|
||||
// Headers
|
||||
let mut headers: Vec<(String, String)> = Vec::new();
|
||||
headers.push(("Content-Type".to_string(), "application/json".to_string()));
|
||||
if use_azure {
|
||||
headers.push(("api-key".to_string(), api_key));
|
||||
} else {
|
||||
headers.push(("Authorization".to_string(), format!("Bearer {}", api_key)));
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
model: cfg.model.clone(),
|
||||
dim,
|
||||
client,
|
||||
agent,
|
||||
endpoint,
|
||||
auth_header_name: auth_name,
|
||||
auth_header_value: auth_val,
|
||||
headers,
|
||||
use_azure,
|
||||
})
|
||||
}
|
||||
@@ -237,21 +228,26 @@ impl OpenAIEmbedder {
|
||||
.insert("dimensions".to_string(), json!(self.dim));
|
||||
}
|
||||
|
||||
let mut req = self.client.post(&self.endpoint);
|
||||
// Add auth header dynamically
|
||||
req = req.header(self.auth_header_name.clone(), self.auth_header_value.clone());
|
||||
|
||||
let resp = req
|
||||
.json(&body)
|
||||
.send()
|
||||
.map_err(|e| DBError(format!("HTTP request failed: {}", e)))?;
|
||||
if !resp.status().is_success() {
|
||||
let code = resp.status();
|
||||
let text = resp.text().unwrap_or_default();
|
||||
return Err(DBError(format!("Embeddings API error {}: {}", code, text)));
|
||||
// Build request
|
||||
let mut req = self.agent.post(&self.endpoint);
|
||||
for (k, v) in &self.headers {
|
||||
req = req.set(k, v);
|
||||
}
|
||||
let val: serde_json::Value = resp
|
||||
.json()
|
||||
|
||||
// Send and handle errors
|
||||
let resp = req.send_json(body);
|
||||
let text = match resp {
|
||||
Ok(r) => r
|
||||
.into_string()
|
||||
.map_err(|e| DBError(format!("Failed to read embeddings response: {}", e)))?,
|
||||
Err(ureq::Error::Status(code, r)) => {
|
||||
let body = r.into_string().unwrap_or_default();
|
||||
return Err(DBError(format!("Embeddings API error {}: {}", code, body)));
|
||||
}
|
||||
Err(e) => return Err(DBError(format!("HTTP request failed: {}", e))),
|
||||
};
|
||||
|
||||
let val: serde_json::Value = serde_json::from_str(&text)
|
||||
.map_err(|e| DBError(format!("Invalid JSON from embeddings API: {}", e)))?;
|
||||
|
||||
let data = val
|
||||
|
||||
31
src/rpc.rs
31
src/rpc.rs
@@ -1057,10 +1057,22 @@ impl RpcServer for RpcServerImpl {
|
||||
if !server.has_write_permission() {
|
||||
return Err(jsonrpsee::types::ErrorObjectOwned::owned(-32000, "write permission denied", None::<()>));
|
||||
}
|
||||
// Resolve embedder and run blocking embedding off the async runtime
|
||||
// Resolve embedder and run embedding on a plain OS thread (avoid dropping any runtime in async context)
|
||||
let embedder = server.get_embedder_for(&name)
|
||||
.map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?;
|
||||
let vector = embedder.embed(&text)
|
||||
.map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?;
|
||||
let (tx, rx) = tokio::sync::oneshot::channel();
|
||||
let emb_arc = embedder.clone();
|
||||
let text_cl = text.clone();
|
||||
std::thread::spawn(move || {
|
||||
let res = emb_arc.embed(&text_cl);
|
||||
let _ = tx.send(res);
|
||||
});
|
||||
let vector = match rx.await {
|
||||
Ok(Ok(v)) => v,
|
||||
Ok(Err(e)) => return Err(jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>)),
|
||||
Err(recv_err) => return Err(jsonrpsee::types::ErrorObjectOwned::owned(-32000, format!("embedding thread error: {}", recv_err), None::<()>)),
|
||||
};
|
||||
server.lance_store()
|
||||
.map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?
|
||||
.store_vector(&name, &id, vector, meta.unwrap_or_default(), Some(text)).await
|
||||
@@ -1087,10 +1099,21 @@ impl RpcServer for RpcServerImpl {
|
||||
if !server.has_read_permission() {
|
||||
return Err(jsonrpsee::types::ErrorObjectOwned::owned(-32000, "read permission denied", None::<()>));
|
||||
}
|
||||
// Resolve embedder and run embedding on a plain OS thread (avoid dropping any runtime in async context)
|
||||
let embedder = server.get_embedder_for(&name)
|
||||
.map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?;
|
||||
let qv = embedder.embed(&text)
|
||||
.map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?;
|
||||
let (tx, rx) = tokio::sync::oneshot::channel();
|
||||
let emb_arc = embedder.clone();
|
||||
let text_cl = text.clone();
|
||||
std::thread::spawn(move || {
|
||||
let res = emb_arc.embed(&text_cl);
|
||||
let _ = tx.send(res);
|
||||
});
|
||||
let qv = match rx.await {
|
||||
Ok(Ok(v)) => v,
|
||||
Ok(Err(e)) => return Err(jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>)),
|
||||
Err(recv_err) => return Err(jsonrpsee::types::ErrorObjectOwned::owned(-32000, format!("embedding thread error: {}", recv_err), None::<()>)),
|
||||
};
|
||||
let results = server.lance_store()
|
||||
.map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?
|
||||
.search_vectors(&name, qv, k, filter, return_fields).await
|
||||
|
||||
Reference in New Issue
Block a user