WIP2: implementing lancedb: created embedding abstraction, server-side per-dataset embedding config + updated RPC endpoints

This commit is contained in:
Maxime Van Hees
2025-09-29 13:17:34 +02:00
parent 6a4e2819bf
commit cf66f4c304
6 changed files with 595 additions and 99 deletions

View File

@@ -1,4 +1,4 @@
use crate::{error::DBError, protocol::Protocol, server::Server};
use crate::{error::DBError, protocol::Protocol, server::Server, embedding::{EmbeddingConfig, EmbeddingProvider}};
use tokio::time::{timeout, Duration};
use futures::future::select_all;
@@ -127,20 +127,20 @@ pub enum Cmd {
reducers: Vec<String>,
},
// LanceDB vector search commands
// LanceDB text-first commands (no user-provided vectors)
LanceCreate {
name: String,
dim: usize,
},
LanceStore {
LanceStoreText {
name: String,
id: String,
vector: Vec<f32>,
text: String,
meta: Vec<(String, String)>,
},
LanceSearch {
LanceSearchText {
name: String,
vector: Vec<f32>,
text: String,
k: usize,
filter: Option<String>,
return_fields: Option<Vec<String>>,
@@ -150,6 +150,16 @@ pub enum Cmd {
index_type: String,
params: Vec<(String, String)>,
},
// Embedding configuration per dataset
LanceEmbeddingConfigSet {
name: String,
provider: String,
model: String,
params: Vec<(String, String)>,
},
LanceEmbeddingConfigGet {
name: String,
},
LanceList,
LanceInfo {
name: String,
@@ -862,9 +872,9 @@ impl Cmd {
Cmd::LanceCreate { name, dim }
}
"lance.store" => {
// LANCE.STORE name ID id VECTOR v1 v2 ... [META k v ...]
// LANCE.STORE name ID <id> TEXT <text> [META k v ...]
if cmd.len() < 6 {
return Err(DBError("ERR LANCE.STORE requires: name ID <id> VECTOR v1 v2 ... [META k v ...]".to_string()));
return Err(DBError("ERR LANCE.STORE requires: name ID <id> TEXT <text> [META k v ...]".to_string()));
}
let name = cmd[1].clone();
let mut i = 2;
@@ -873,16 +883,16 @@ impl Cmd {
}
let id = cmd[i + 1].clone();
i += 2;
if i >= cmd.len() || cmd[i].to_uppercase() != "VECTOR" {
return Err(DBError("ERR LANCE.STORE requires VECTOR <f32...>".to_string()));
if i >= cmd.len() || cmd[i].to_uppercase() != "TEXT" {
return Err(DBError("ERR LANCE.STORE requires TEXT <text>".to_string()));
}
i += 1;
let mut vector: Vec<f32> = Vec::new();
while i < cmd.len() && cmd[i].to_uppercase() != "META" {
let v: f32 = cmd[i].parse().map_err(|_| DBError("ERR vector element must be a float32".to_string()))?;
vector.push(v);
i += 1;
if i >= cmd.len() {
return Err(DBError("ERR LANCE.STORE requires TEXT <text>".to_string()));
}
let text = cmd[i].clone();
i += 1;
let mut meta: Vec<(String, String)> = Vec::new();
if i < cmd.len() && cmd[i].to_uppercase() == "META" {
i += 1;
@@ -891,28 +901,28 @@ impl Cmd {
i += 2;
}
}
Cmd::LanceStore { name, id, vector, meta }
Cmd::LanceStoreText { name, id, text, meta }
}
"lance.search" => {
// LANCE.SEARCH name K k VECTOR v1 v2 ... [FILTER expr] [RETURN n fields...]
// LANCE.SEARCH name K <k> QUERY <text> [FILTER expr] [RETURN n fields...]
if cmd.len() < 6 {
return Err(DBError("ERR LANCE.SEARCH requires: name K <k> VECTOR v1 v2 ... [FILTER expr] [RETURN n fields...]".to_string()));
return Err(DBError("ERR LANCE.SEARCH requires: name K <k> QUERY <text> [FILTER expr] [RETURN n fields...]".to_string()));
}
let name = cmd[1].clone();
if cmd[2].to_uppercase() != "K" {
return Err(DBError("ERR LANCE.SEARCH requires K <k>".to_string()));
}
let k: usize = cmd[3].parse().map_err(|_| DBError("ERR K must be an integer".to_string()))?;
if cmd[4].to_uppercase() != "VECTOR" {
return Err(DBError("ERR LANCE.SEARCH requires VECTOR <f32...>".to_string()));
if cmd[4].to_uppercase() != "QUERY" {
return Err(DBError("ERR LANCE.SEARCH requires QUERY <text>".to_string()));
}
let mut i = 5;
let mut vector: Vec<f32> = Vec::new();
while i < cmd.len() && !["FILTER","RETURN"].contains(&cmd[i].to_uppercase().as_str()) {
let v: f32 = cmd[i].parse().map_err(|_| DBError("ERR vector element must be a float32".to_string()))?;
vector.push(v);
i += 1;
if i >= cmd.len() {
return Err(DBError("ERR LANCE.SEARCH requires QUERY <text>".to_string()));
}
let text = cmd[i].clone();
i += 1;
let mut filter: Option<String> = None;
let mut return_fields: Option<Vec<String>> = None;
while i < cmd.len() {
@@ -942,7 +952,7 @@ impl Cmd {
_ => { i += 1; }
}
}
Cmd::LanceSearch { name, vector, k, filter, return_fields }
Cmd::LanceSearchText { name, text, k, filter, return_fields }
}
"lance.createindex" => {
// LANCE.CREATEINDEX name TYPE t [PARAM k v ...]
@@ -962,6 +972,60 @@ impl Cmd {
}
Cmd::LanceCreateIndex { name, index_type, params }
}
"lance.embedding" => {
// LANCE.EMBEDDING CONFIG SET name PROVIDER p MODEL m [PARAM k v ...]
// LANCE.EMBEDDING CONFIG GET name
if cmd.len() < 3 || cmd[1].to_uppercase() != "CONFIG" {
return Err(DBError("ERR LANCE.EMBEDDING requires CONFIG subcommand".to_string()));
}
if cmd.len() >= 4 && cmd[2].to_uppercase() == "SET" {
if cmd.len() < 8 {
return Err(DBError("ERR LANCE.EMBEDDING CONFIG SET requires: SET name PROVIDER p MODEL m [PARAM k v ...]".to_string()));
}
let name = cmd[3].clone();
let mut i = 4;
let mut provider: Option<String> = None;
let mut model: Option<String> = None;
let mut params: Vec<(String, String)> = Vec::new();
while i < cmd.len() {
match cmd[i].to_uppercase().as_str() {
"PROVIDER" => {
if i + 1 >= cmd.len() {
return Err(DBError("ERR PROVIDER requires a value".to_string()));
}
provider = Some(cmd[i + 1].clone());
i += 2;
}
"MODEL" => {
if i + 1 >= cmd.len() {
return Err(DBError("ERR MODEL requires a value".to_string()));
}
model = Some(cmd[i + 1].clone());
i += 2;
}
"PARAM" => {
i += 1;
while i + 1 < cmd.len() {
params.push((cmd[i].clone(), cmd[i + 1].clone()));
i += 2;
}
}
_ => {
// Unknown token: skip it and keep scanning; advancing `i` guarantees the loop terminates
i += 1;
}
}
}
let provider = provider.ok_or_else(|| DBError("ERR missing PROVIDER".to_string()))?;
let model = model.ok_or_else(|| DBError("ERR missing MODEL".to_string()))?;
Cmd::LanceEmbeddingConfigSet { name, provider, model, params }
} else if cmd.len() == 4 && cmd[2].to_uppercase() == "GET" {
let name = cmd[3].clone();
Cmd::LanceEmbeddingConfigGet { name }
} else {
return Err(DBError("ERR LANCE.EMBEDDING CONFIG supports: SET ... | GET name".to_string()));
}
}
"lance.list" => {
if cmd.len() != 1 {
return Err(DBError("ERR LANCE.LIST takes no arguments".to_string()));
@@ -1070,8 +1134,10 @@ impl Cmd {
| Cmd::Command(..)
| Cmd::Info(..)
| Cmd::LanceCreate { .. }
| Cmd::LanceStore { .. }
| Cmd::LanceSearch { .. }
| Cmd::LanceStoreText { .. }
| Cmd::LanceSearchText { .. }
| Cmd::LanceEmbeddingConfigSet { .. }
| Cmd::LanceEmbeddingConfigGet { .. }
| Cmd::LanceCreateIndex { .. }
| Cmd::LanceList
| Cmd::LanceInfo { .. }
@@ -1104,8 +1170,10 @@ impl Cmd {
if !is_lance_backend {
match &self {
Cmd::LanceCreate { .. }
| Cmd::LanceStore { .. }
| Cmd::LanceSearch { .. }
| Cmd::LanceStoreText { .. }
| Cmd::LanceSearchText { .. }
| Cmd::LanceEmbeddingConfigSet { .. }
| Cmd::LanceEmbeddingConfigGet { .. }
| Cmd::LanceCreateIndex { .. }
| Cmd::LanceList
| Cmd::LanceInfo { .. }
@@ -1249,18 +1317,66 @@ impl Cmd {
Err(e) => Ok(Protocol::err(&e.0)),
}
}
Cmd::LanceStore { name, id, vector, meta } => {
Cmd::LanceEmbeddingConfigSet { name, provider, model, params } => {
if !server.has_write_permission() {
return Ok(Protocol::err("ERR write permission denied"));
}
let meta_map: std::collections::HashMap<String, String> = meta.into_iter().collect();
match server.lance_store()?.store_vector(&name, &id, vector, meta_map).await {
// Map provider string to enum
let p_lc = provider.to_lowercase();
let prov = match p_lc.as_str() {
"test-hash" | "testhash" => EmbeddingProvider::TestHash,
"fastembed" | "lancefastembed" => EmbeddingProvider::LanceFastEmbed,
"openai" | "lanceopenai" => EmbeddingProvider::LanceOpenAI,
other => EmbeddingProvider::LanceOther(other.to_string()),
};
let cfg = EmbeddingConfig {
provider: prov,
model,
params: params.into_iter().collect(),
};
match server.set_dataset_embedding_config(&name, &cfg) {
Ok(()) => Ok(Protocol::SimpleString("OK".to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
Cmd::LanceSearch { name, vector, k, filter, return_fields } => {
match server.lance_store()?.search_vectors(&name, vector, k, filter, return_fields).await {
Cmd::LanceEmbeddingConfigGet { name } => {
match server.get_dataset_embedding_config(&name) {
Ok(cfg) => {
let mut arr = Vec::new();
arr.push(Protocol::BulkString("provider".to_string()));
arr.push(Protocol::BulkString(match cfg.provider {
EmbeddingProvider::TestHash => "test-hash".to_string(),
EmbeddingProvider::LanceFastEmbed => "lancefastembed".to_string(),
EmbeddingProvider::LanceOpenAI => "lanceopenai".to_string(),
EmbeddingProvider::LanceOther(ref s) => s.clone(),
}));
arr.push(Protocol::BulkString("model".to_string()));
arr.push(Protocol::BulkString(cfg.model.clone()));
arr.push(Protocol::BulkString("params".to_string()));
arr.push(Protocol::BulkString(serde_json::to_string(&cfg.params).unwrap_or_else(|_| "{}".to_string())));
Ok(Protocol::Array(arr))
}
Err(e) => Ok(Protocol::err(&e.0)),
}
}
Cmd::LanceStoreText { name, id, text, meta } => {
if !server.has_write_permission() {
return Ok(Protocol::err("ERR write permission denied"));
}
// Resolve embedder and embed text
let embedder = server.get_embedder_for(&name)?;
let vector = embedder.embed(&text)?;
let meta_map: std::collections::HashMap<String, String> = meta.into_iter().collect();
match server.lance_store()?.store_vector(&name, &id, vector, meta_map, Some(text)).await {
Ok(()) => Ok(Protocol::SimpleString("OK".to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
Cmd::LanceSearchText { name, text, k, filter, return_fields } => {
// Resolve embedder and embed query text
let embedder = server.get_embedder_for(&name)?;
let qv = embedder.embed(&text)?;
match server.lance_store()?.search_vectors(&name, qv, k, filter, return_fields).await {
Ok(results) => {
// Encode as array of [id, score, [k1, v1, k2, v2, ...]]
let mut arr = Vec::new();