Files
herodb/specs/lance_implementation.md
2025-08-23 04:58:41 +02:00

23 KiB

Looking at your request to ensure embedding happens with an external embedding provider configured in Redis at config:core:aiembed:url, I'll modify the implementation to use an external embedding service instead of local models.

Updated Implementation with External Embedding Provider

1. Update Dependencies in Cargo.toml

[dependencies]
# ... existing dependencies ...
lance = "0.18"
lance-index = "0.18"
lance-linalg = "0.18"
# Arrow/Parquet versions must match what lance 0.18 links against.
arrow = "52"
arrow-array = "52"
arrow-schema = "52"
parquet = "52"
uuid = { version = "1.10", features = ["v4"] }
reqwest = { version = "0.11", features = ["json"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
base64 = "0.22"
image = "0.25"
# TryStreamExt::try_collect is used to gather Lance query result streams.
futures = "0.3"

2. Create Enhanced Lance Module with External Embedding

Create src/lance_store.rs:

use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::RwLock;

use arrow::array::{Float32Array, StringArray, BinaryArray, ArrayRef};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use lance::dataset::{Dataset, WriteParams, WriteMode};
use lance::index::vector::VectorIndexParams;
use lance_index::vector::pq::PQBuildParams;
use lance_index::vector::ivf::IvfBuildParams;

use serde::{Deserialize, Serialize};
use crate::error::DBError;
use crate::cmd::Protocol;

// Request body sent to the external embedding service. Callers set
// whichever modality applies; `call_embedding_service` passes the other
// as None, which serializes as JSON null.
#[derive(Debug, Serialize, Deserialize)]
struct EmbeddingRequest {
    texts: Option<Vec<String>>,
    images: Option<Vec<String>>, // base64 encoded
    model: Option<String>, // None lets the service pick its default model
}

// Response body returned by the external embedding service.
#[derive(Debug, Serialize, Deserialize)]
struct EmbeddingResponse {
    // One vector per input (assumed to be in request order — confirm
    // against the actual service contract).
    embeddings: Vec<Vec<f32>>,
    model: String, // model the service actually used
    usage: Option<HashMap<String, u32>>, // e.g. token counts, if reported
}

// Vector store backed by Lance datasets on local disk, with embeddings
// generated by an external HTTP service configured in Redis.
pub struct LanceStore {
    // Cache of opened datasets keyed by dataset name; the async RwLock
    // lets concurrent readers share cached handles.
    datasets: Arc<RwLock<HashMap<String, Arc<Dataset>>>>,
    // Root directory holding one `<name>.lance` directory per dataset.
    data_dir: PathBuf,
    // Reused for every embedding-service call (connection pooling, timeout).
    http_client: reqwest::Client,
}

impl LanceStore {
    pub async fn new(data_dir: PathBuf) -> Result<Self, DBError> {
        // Create data directory if it doesn't exist
        std::fs::create_dir_all(&data_dir)
            .map_err(|e| DBError(format!("Failed to create Lance data directory: {}", e)))?;
        
        let http_client = reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(30))
            .build()
            .map_err(|e| DBError(format!("Failed to create HTTP client: {}", e)))?;
        
        Ok(Self {
            datasets: Arc::new(RwLock::new(HashMap::new())),
            data_dir,
            http_client,
        })
    }
    
    /// Get embedding service URL from Redis config
    async fn get_embedding_url(&self, server: &crate::server::Server) -> Result<String, DBError> {
        // Get the embedding URL from Redis config
        let key = "config:core:aiembed:url";
        
        // Use HGET to retrieve the URL from Redis hash
        let cmd = crate::cmd::Cmd::HGet {
            key: key.to_string(),
            field: "url".to_string(),
        };
        
        // Execute command to get the config
        let result = cmd.run(server).await?;
        
        match result {
            Protocol::BulkString(url) => Ok(url),
            Protocol::SimpleString(url) => Ok(url),
            Protocol::Nil => Err(DBError(
                "Embedding service URL not configured. Set it with: HSET config:core:aiembed:url url <YOUR_EMBEDDING_SERVICE_URL>".to_string()
            )),
            _ => Err(DBError("Invalid embedding URL configuration".to_string())),
        }
    }
    
    /// Call external embedding service
    async fn call_embedding_service(
        &self,
        server: &crate::server::Server,
        texts: Option<Vec<String>>,
        images: Option<Vec<String>>,
    ) -> Result<Vec<Vec<f32>>, DBError> {
        let url = self.get_embedding_url(server).await?;
        
        let request = EmbeddingRequest {
            texts,
            images,
            model: None, // Let the service use its default
        };
        
        let response = self.http_client
            .post(&url)
            .json(&request)
            .send()
            .await
            .map_err(|e| DBError(format!("Failed to call embedding service: {}", e)))?;
        
        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            return Err(DBError(format!(
                "Embedding service returned error {}: {}", 
                status, error_text
            )));
        }
        
        let embedding_response: EmbeddingResponse = response
            .json()
            .await
            .map_err(|e| DBError(format!("Failed to parse embedding response: {}", e)))?;
        
        Ok(embedding_response.embeddings)
    }
    
    pub async fn embed_text(
        &self, 
        server: &crate::server::Server,
        texts: Vec<String>
    ) -> Result<Vec<Vec<f32>>, DBError> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }
        
        self.call_embedding_service(server, Some(texts), None).await
    }
    
    pub async fn embed_image(
        &self,
        server: &crate::server::Server,
        image_bytes: Vec<u8>
    ) -> Result<Vec<f32>, DBError> {
        // Convert image bytes to base64
        let base64_image = base64::encode(&image_bytes);
        
        let embeddings = self.call_embedding_service(
            server, 
            None, 
            Some(vec![base64_image])
        ).await?;
        
        embeddings.into_iter()
            .next()
            .ok_or_else(|| DBError("No embedding returned for image".to_string()))
    }
    
    pub async fn create_dataset(
        &self,
        name: &str,
        schema: Schema,
    ) -> Result<(), DBError> {
        let dataset_path = self.data_dir.join(format!("{}.lance", name));
        
        // Create empty dataset with schema
        let write_params = WriteParams {
            mode: WriteMode::Create,
            ..Default::default()
        };
        
        // Create an empty RecordBatch with the schema
        let empty_batch = RecordBatch::new_empty(Arc::new(schema));
        let batches = vec![empty_batch];
        
        let dataset = Dataset::write(
            batches,
            dataset_path.to_str().unwrap(),
            Some(write_params)
        ).await
        .map_err(|e| DBError(format!("Failed to create dataset: {}", e)))?;
        
        let mut datasets = self.datasets.write().await;
        datasets.insert(name.to_string(), Arc::new(dataset));
        
        Ok(())
    }
    
    pub async fn write_vectors(
        &self,
        dataset_name: &str,
        vectors: Vec<Vec<f32>>,
        metadata: Option<HashMap<String, Vec<String>>>,
    ) -> Result<usize, DBError> {
        let dataset_path = self.data_dir.join(format!("{}.lance", dataset_name));
        
        // Open or get cached dataset
        let dataset = self.get_or_open_dataset(dataset_name).await?;
        
        // Build RecordBatch
        let num_vectors = vectors.len();
        if num_vectors == 0 {
            return Ok(0);
        }
        
        let dim = vectors.first()
            .ok_or_else(|| DBError("Empty vectors".to_string()))?
            .len();
        
        // Flatten vectors
        let flat_vectors: Vec<f32> = vectors.into_iter().flatten().collect();
        let vector_array = Float32Array::from(flat_vectors);
        let vector_array = arrow::array::FixedSizeListArray::try_new_from_values(
            vector_array, 
            dim as i32
        ).map_err(|e| DBError(format!("Failed to create vector array: {}", e)))?;
        
        let mut arrays: Vec<ArrayRef> = vec![Arc::new(vector_array)];
        let mut fields = vec![Field::new(
            "vector",
            DataType::FixedSizeList(
                Arc::new(Field::new("item", DataType::Float32, true)),
                dim as i32
            ),
            false
        )];
        
        // Add metadata columns if provided
        if let Some(metadata) = metadata {
            for (key, values) in metadata {
                if values.len() != num_vectors {
                    return Err(DBError(format!(
                        "Metadata field '{}' has {} values but expected {}", 
                        key, values.len(), num_vectors
                    )));
                }
                let array = StringArray::from(values);
                arrays.push(Arc::new(array));
                fields.push(Field::new(&key, DataType::Utf8, true));
            }
        }
        
        let schema = Arc::new(Schema::new(fields));
        let batch = RecordBatch::try_new(schema, arrays)
            .map_err(|e| DBError(format!("Failed to create RecordBatch: {}", e)))?;
        
        // Append to dataset
        let write_params = WriteParams {
            mode: WriteMode::Append,
            ..Default::default()
        };
        
        Dataset::write(
            vec![batch],
            dataset_path.to_str().unwrap(),
            Some(write_params)
        ).await
        .map_err(|e| DBError(format!("Failed to write to dataset: {}", e)))?;
        
        // Refresh cached dataset
        let mut datasets = self.datasets.write().await;
        datasets.remove(dataset_name);
        
        Ok(num_vectors)
    }
    
    pub async fn search_vectors(
        &self,
        dataset_name: &str,
        query_vector: Vec<f32>,
        k: usize,
        nprobes: Option<usize>,
        refine_factor: Option<usize>,
    ) -> Result<Vec<(f32, HashMap<String, String>)>, DBError> {
        let dataset = self.get_or_open_dataset(dataset_name).await?;
        
        // Build query
        let mut query = dataset.scan();
        query = query.nearest(
            "vector",
            &query_vector,
            k,
        ).map_err(|e| DBError(format!("Failed to build search query: {}", e)))?;
        
        if let Some(nprobes) = nprobes {
            query = query.nprobes(nprobes);
        }
        
        if let Some(refine) = refine_factor {
            query = query.refine_factor(refine);
        }
        
        // Execute search
        let results = query
            .try_into_stream()
            .await
            .map_err(|e| DBError(format!("Failed to execute search: {}", e)))?
            .try_collect::<Vec<_>>()
            .await
            .map_err(|e| DBError(format!("Failed to collect results: {}", e)))?;
        
        // Process results
        let mut output = Vec::new();
        for batch in results {
            // Get distances
            let distances = batch
                .column_by_name("_distance")
                .ok_or_else(|| DBError("No distance column".to_string()))?
                .as_any()
                .downcast_ref::<Float32Array>()
                .ok_or_else(|| DBError("Invalid distance type".to_string()))?;
            
            // Get metadata
            for i in 0..batch.num_rows() {
                let distance = distances.value(i);
                let mut metadata = HashMap::new();
                
                for field in batch.schema().fields() {
                    if field.name() != "vector" && field.name() != "_distance" {
                        if let Some(col) = batch.column_by_name(field.name()) {
                            if let Some(str_array) = col.as_any().downcast_ref::<StringArray>() {
                                if !str_array.is_null(i) {
                                    metadata.insert(
                                        field.name().to_string(),
                                        str_array.value(i).to_string()
                                    );
                                }
                            }
                        }
                    }
                }
                
                output.push((distance, metadata));
            }
        }
        
        Ok(output)
    }
    
    pub async fn store_multimodal(
        &self,
        server: &crate::server::Server,
        dataset_name: &str,
        text: Option<String>,
        image_bytes: Option<Vec<u8>>,
        metadata: HashMap<String, String>,
    ) -> Result<String, DBError> {
        // Generate ID
        let id = uuid::Uuid::new_v4().to_string();
        
        // Generate embeddings using external service
        let embedding = if let Some(text) = text.as_ref() {
            self.embed_text(server, vec![text.clone()]).await?
                .into_iter()
                .next()
                .ok_or_else(|| DBError("No embedding returned".to_string()))?
        } else if let Some(img) = image_bytes.as_ref() {
            self.embed_image(server, img.clone()).await?
        } else {
            return Err(DBError("No text or image provided".to_string()));
        };
        
        // Prepare metadata
        let mut full_metadata = metadata;
        full_metadata.insert("id".to_string(), id.clone());
        if let Some(text) = text {
            full_metadata.insert("text".to_string(), text);
        }
        if let Some(img) = image_bytes {
            full_metadata.insert("image_base64".to_string(), base64::encode(img));
        }
        
        // Convert metadata to column vectors
        let mut metadata_cols = HashMap::new();
        for (key, value) in full_metadata {
            metadata_cols.insert(key, vec![value]);
        }
        
        // Write to dataset
        self.write_vectors(dataset_name, vec![embedding], Some(metadata_cols)).await?;
        
        Ok(id)
    }
    
    pub async fn search_with_text(
        &self,
        server: &crate::server::Server,
        dataset_name: &str,
        query_text: String,
        k: usize,
        nprobes: Option<usize>,
        refine_factor: Option<usize>,
    ) -> Result<Vec<(f32, HashMap<String, String>)>, DBError> {
        // Embed the query text using external service
        let embeddings = self.embed_text(server, vec![query_text]).await?;
        let query_vector = embeddings.into_iter()
            .next()
            .ok_or_else(|| DBError("No embedding returned for query".to_string()))?;
        
        // Search with the embedding
        self.search_vectors(dataset_name, query_vector, k, nprobes, refine_factor).await
    }
    
    pub async fn create_index(
        &self,
        dataset_name: &str,
        index_type: &str,
        num_partitions: Option<usize>,
        num_sub_vectors: Option<usize>,
    ) -> Result<(), DBError> {
        let dataset = self.get_or_open_dataset(dataset_name).await?;
        
        let mut params = VectorIndexParams::default();
        
        match index_type.to_uppercase().as_str() {
            "IVF_PQ" => {
                params.ivf = IvfBuildParams {
                    num_partitions: num_partitions.unwrap_or(256),
                    ..Default::default()
                };
                params.pq = PQBuildParams {
                    num_sub_vectors: num_sub_vectors.unwrap_or(16),
                    ..Default::default()
                };
            }
            _ => return Err(DBError(format!("Unsupported index type: {}", index_type))),
        }
        
        dataset.create_index(
            &["vector"],
            lance::index::IndexType::Vector,
            None,
            &params,
            true
        ).await
        .map_err(|e| DBError(format!("Failed to create index: {}", e)))?;
        
        Ok(())
    }
    
    async fn get_or_open_dataset(&self, name: &str) -> Result<Arc<Dataset>, DBError> {
        let mut datasets = self.datasets.write().await;
        
        if let Some(dataset) = datasets.get(name) {
            return Ok(dataset.clone());
        }
        
        let dataset_path = self.data_dir.join(format!("{}.lance", name));
        if !dataset_path.exists() {
            return Err(DBError(format!("Dataset '{}' does not exist", name)));
        }
        
        let dataset = Dataset::open(dataset_path.to_str().unwrap())
            .await
            .map_err(|e| DBError(format!("Failed to open dataset: {}", e)))?;
        
        let dataset = Arc::new(dataset);
        datasets.insert(name.to_string(), dataset.clone());
        
        Ok(dataset)
    }
    
    pub async fn list_datasets(&self) -> Result<Vec<String>, DBError> {
        let mut datasets = Vec::new();
        
        let entries = std::fs::read_dir(&self.data_dir)
            .map_err(|e| DBError(format!("Failed to read data directory: {}", e)))?;
        
        for entry in entries {
            let entry = entry.map_err(|e| DBError(format!("Failed to read entry: {}", e)))?;
            let path = entry.path();
            
            if path.is_dir() {
                if let Some(name) = path.file_name() {
                    if let Some(name_str) = name.to_str() {
                        if name_str.ends_with(".lance") {
                            let dataset_name = name_str.trim_end_matches(".lance");
                            datasets.push(dataset_name.to_string());
                        }
                    }
                }
            }
        }
        
        Ok(datasets)
    }
    
    pub async fn drop_dataset(&self, name: &str) -> Result<(), DBError> {
        // Remove from cache
        let mut datasets = self.datasets.write().await;
        datasets.remove(name);
        
        // Delete from disk
        let dataset_path = self.data_dir.join(format!("{}.lance", name));
        if dataset_path.exists() {
            std::fs::remove_dir_all(dataset_path)
                .map_err(|e| DBError(format!("Failed to delete dataset: {}", e)))?;
        }
        
        Ok(())
    }
    
    pub async fn get_dataset_info(&self, name: &str) -> Result<HashMap<String, String>, DBError> {
        let dataset = self.get_or_open_dataset(name).await?;
        
        let mut info = HashMap::new();
        info.insert("name".to_string(), name.to_string());
        info.insert("version".to_string(), dataset.version().to_string());
        info.insert("num_rows".to_string(), dataset.count_rows().await?.to_string());
        
        // Get schema info
        let schema = dataset.schema();
        let fields: Vec<String> = schema.fields()
            .iter()
            .map(|f| format!("{}:{}", f.name(), f.data_type()))
            .collect();
        info.insert("schema".to_string(), fields.join(", "));
        
        Ok(info)
    }
}

3. Update Command Implementations

Update the command implementations to pass the server reference for embedding service access:

// In cmd.rs, update the lance command implementations

async fn lance_store_cmd(
    server: &Server,
    dataset: &str,
    text: Option<String>,
    image_base64: Option<String>,
    metadata: HashMap<String, String>,
) -> Result<Protocol, DBError> {
    let lance_store = server.lance_store()?;
    
    // Decode image if provided
    let image_bytes = if let Some(b64) = image_base64 {
        Some(base64::decode(b64).map_err(|e| 
            DBError(format!("Invalid base64 image: {}", e)))?)
    } else {
        None
    };
    
    // Pass server reference for embedding service access
    let id = lance_store.store_multimodal(
        server,  // Pass server to access Redis config
        dataset,
        text,
        image_bytes,
        metadata,
    ).await?;
    
    Ok(Protocol::BulkString(id))
}

/// LANCE.EMBED.TEXT: embed one or more texts via the external service and
/// return each vector rendered as a bracketed, comma-separated string.
async fn lance_embed_text_cmd(
    server: &Server,
    texts: &[String],
) -> Result<Protocol, DBError> {
    let lance_store = server.lance_store()?;
    
    // The server handle gives access to the embedding-service URL in Redis.
    let embeddings = lance_store.embed_text(server, texts.to_vec()).await?;
    
    // Render each embedding as "[v1,v2,...]" and wrap it in a bulk string.
    let output: Vec<Protocol> = embeddings
        .into_iter()
        .map(|embedding| {
            let rendered: Vec<String> = embedding.iter().map(|f| f.to_string()).collect();
            Protocol::BulkString(format!("[{}]", rendered.join(",")))
        })
        .collect();
    
    Ok(Protocol::Array(output))
}

/// LANCE.SEARCH.TEXT: embed `query_text` via the external service, run a
/// nearest-neighbour search, and return (distance, metadata-JSON) pairs.
async fn lance_search_text_cmd(
    server: &Server,
    dataset: &str,
    query_text: &str,
    k: usize,
    nprobes: Option<usize>,
    refine_factor: Option<usize>,
) -> Result<Protocol, DBError> {
    let lance_store = server.lance_store()?;
    
    // The query text is embedded automatically before searching.
    let results = lance_store.search_with_text(
        server,
        dataset,
        query_text.to_string(),
        k,
        nprobes,
        refine_factor,
    ).await?;
    
    // Each hit becomes a two-element array: distance, then the metadata
    // serialized as JSON (falling back to "{}" if serialization fails).
    let output: Vec<Protocol> = results
        .into_iter()
        .map(|(distance, metadata)| {
            let metadata_json =
                serde_json::to_string(&metadata).unwrap_or_else(|_| "{}".to_string());
            Protocol::Array(vec![
                Protocol::BulkString(distance.to_string()),
                Protocol::BulkString(metadata_json),
            ])
        })
        .collect();
    
    Ok(Protocol::Array(output))
}

// Add new command for text-based search
pub enum Cmd {
    // ... existing commands ...
    // LANCE.SEARCH.TEXT <dataset> <query> K <k> [NPROBES n] [REFINE r]
    LanceSearchText {
        dataset: String,              // dataset (table) name
        query_text: String,           // natural-language query; embedded server-side
        k: usize,                     // number of nearest neighbours to return
        nprobes: Option<usize>,       // IVF partitions to probe (recall/latency knob)
        refine_factor: Option<usize>, // re-ranking factor for PQ results
    },
}

Usage Examples

1. Configure the Embedding Service

First, users need to configure the embedding service URL:

# Configure the embedding service endpoint
redis-cli> HSET config:core:aiembed:url url "http://localhost:8000/embeddings"
OK

# Or use a cloud service
redis-cli> HSET config:core:aiembed:url url "https://api.openai.com/v1/embeddings"
OK

2. Use Lance Commands with Automatic External Embedding

# Create a dataset
redis-cli> LANCE.CREATE products DIM 1536 SCHEMA name:string price:float category:string
OK

# Store text with automatic embedding (calls external service)
redis-cli> LANCE.STORE products TEXT "Wireless noise-canceling headphones with 30-hour battery" name:AirPods price:299.99 category:Electronics
"uuid-123-456"

# Search using text query (automatically embeds the query)
redis-cli> LANCE.SEARCH.TEXT products "best headphones for travel" K 5
1) "0.92" 
2) "{\"id\":\"uuid-123\",\"name\":\"AirPods\",\"price\":\"299.99\"}"

# Get embeddings directly
redis-cli> LANCE.EMBED.TEXT "This text will be embedded"
1) "[0.123, 0.456, 0.789, ...]"

External Embedding Service API Specification

The external embedding service should accept POST requests with this format:

// Request
{
  "texts": ["text1", "text2"],  // Optional
  "images": ["base64_img1"],    // Optional
  "model": "text-embedding-ada-002"  // Optional
}

// Response
{
  "embeddings": [[0.1, 0.2, ...], [0.3, 0.4, ...]],
  "model": "text-embedding-ada-002",
  "usage": {
    "prompt_tokens": 100,
    "total_tokens": 100
  }
}

Error Handling

The implementation includes comprehensive error handling:

  1. Missing Configuration: Clear error message if embedding URL not configured
  2. Service Failures: Graceful handling of embedding service errors
  3. Timeout Protection: 30-second timeout for embedding requests
  4. Retry Logic: not implemented yet; could be added for resilience

Benefits of This Approach

  1. Flexibility: Supports any embedding service with compatible API
  2. Cost Control: Use your preferred embedding provider
  3. Scalability: Embedding service can be scaled independently
  4. Consistency: All embeddings use the same configured service
  5. Centralized configuration: endpoints (and, in future, API keys) are stored in Redis; note that Redis does not encrypt data at rest, so access to the instance must itself be restricted

This implementation ensures that all embedding operations go through the external service configured in Redis, providing a clean separation between the vector database functionality and the embedding generation.

TODO EXTRA:

  • secret for the embedding service API key