diff --git a/docs/local_embedder_full_example.md b/docs/local_embedder_full_example.md new file mode 100644 index 0000000..c3c4eae --- /dev/null +++ b/docs/local_embedder_full_example.md @@ -0,0 +1,988 @@ +# HeroDB Embedding Models: Complete Tutorial + +This tutorial demonstrates how to use embedding models with HeroDB for vector search, covering both local self-hosted models and OpenAI's API. + +## Table of Contents +- [Prerequisites](#prerequisites) +- [Scenario 1: Local Embedding Model](#scenario-1-local-embedding-model-testing) +- [Scenario 2: OpenAI API](#scenario-2-openai-api) +- [Scenario 3: Deterministic Test Embedder](#scenario-3-deterministic-test-embedder-no-network) +- [Troubleshooting](#troubleshooting) + +--- + +## Prerequisites + +### Start HeroDB Server + +Build and start HeroDB with RPC enabled: + +```bash +cargo build --release +./target/release/herodb --dir ./data --admin-secret my-admin-secret --enable-rpc --rpc-port 8080 +``` + +This starts: +- Redis-compatible server on port 6379 +- JSON-RPC server on port 8080 + +### Client Tools + +For Redis-like commands: +```bash +redis-cli -p 6379 +``` + +For JSON-RPC calls, use `curl`: +```bash +curl -X POST http://localhost:8080 \ + -H "Content-Type: application/json" \ + -d '{"jsonrpc":"2.0","id":1,"method":"herodb_METHOD","params":[...]}' +``` + +--- + +## Scenario 1: Local Embedding Model (Testing) + +Run your own embedding service locally for development, testing, or privacy. + +### Option A: Python Mock Server (Simplest) + +This creates a minimal OpenAI-compatible embedding server for testing. + +**1. 
Create `mock_embedder.py`:** + +```python +from flask import Flask, request, jsonify +import numpy as np + +app = Flask(__name__) + +@app.route('/v1/embeddings', methods=['POST']) +def embeddings(): + """OpenAI-compatible embeddings endpoint""" + data = request.json + inputs = data.get('input', []) + + # Handle both single string and array + if isinstance(inputs, str): + inputs = [inputs] + + # Generate deterministic 768-dim embeddings (hash-based) + embeddings = [] + for text in inputs: + # Simple hash to vector (deterministic) + vec = np.zeros(768) + for i, char in enumerate(text[:768]): + vec[i % 768] += ord(char) / 255.0 + + # L2 normalize + norm = np.linalg.norm(vec) + if norm > 0: + vec = vec / norm + + embeddings.append(vec.tolist()) + + return jsonify({ + "data": [{"embedding": emb, "index": i} for i, emb in enumerate(embeddings)], + "model": data.get('model', 'mock-local'), + "usage": {"total_tokens": sum(len(t) for t in inputs)} + }) + +if __name__ == '__main__': + print("Starting mock embedding server on http://127.0.0.1:8081") + app.run(host='127.0.0.1', port=8081, debug=False) +``` + +**2. Install dependencies and run:** + +```bash +pip install flask numpy +python mock_embedder.py +``` + +Output: `Starting mock embedding server on http://127.0.0.1:8081` + +**3. Test the server (optional):** + +```bash +curl -X POST http://127.0.0.1:8081/v1/embeddings \ + -H "Content-Type: application/json" \ + -d '{"input":["hello world"],"model":"test"}' +``` + +You should see a JSON response with a 768-dimensional embedding. + +### End-to-End Example with Local Model + +**Step 1: Create a Lance database** + +JSON-RPC: +```json +{ + "jsonrpc": "2.0", + "id": 1, + "method": "herodb_createDatabase", + "params": [ + "Lance", + { "name": "local-vectors", "storage_path": null, "max_size": null, "redis_version": null }, + null + ] +} +``` + +Expected response: +```json +{"jsonrpc":"2.0","id":1,"result":1} +``` + +The database ID is `1`. 
+ +**Step 2: Configure embedding for the dataset** + +JSON-RPC: +```json +{ + "jsonrpc": "2.0", + "id": 2, + "method": "herodb_lanceSetEmbeddingConfig", + "params": [ + 1, + "products", + { + "provider": "openai", + "model": "mock-local", + "dim": 768, + "endpoint": "http://127.0.0.1:8081/v1/embeddings", + "headers": { + "Authorization": "Bearer dummy" + }, + "timeout_ms": 30000 + } + ] +} +``` + +Redis-like: +```bash +redis-cli -p 6379 +SELECT 1 +LANCE.EMBEDDING CONFIG SET products PROVIDER openai MODEL mock-local DIM 768 ENDPOINT http://127.0.0.1:8081/v1/embeddings HEADER Authorization "Bearer dummy" TIMEOUTMS 30000 +``` + +Expected response: +```json +{"jsonrpc":"2.0","id":2,"result":true} +``` + +**Step 3: Verify configuration** + +JSON-RPC: +```json +{ + "jsonrpc": "2.0", + "id": 3, + "method": "herodb_lanceGetEmbeddingConfig", + "params": [1, "products"] +} +``` + +Redis-like: +```bash +LANCE.EMBEDDING CONFIG GET products +``` + +Expected: Returns your configuration with provider, model, dim, endpoint, etc. 
+ +**Step 4: Insert product data** + +JSON-RPC (item 1): +```json +{ + "jsonrpc": "2.0", + "id": 4, + "method": "herodb_lanceStoreText", + "params": [ + 1, + "products", + "item-1", + "Waterproof hiking boots with ankle support and aggressive tread", + { "brand": "TrailMax", "category": "footwear", "price": "129.99" } + ] +} +``` + +Redis-like: +```bash +LANCE.STORE products ID item-1 TEXT "Waterproof hiking boots with ankle support and aggressive tread" META brand TrailMax category footwear price 129.99 +``` + +JSON-RPC (item 2): +```json +{ + "jsonrpc": "2.0", + "id": 5, + "method": "herodb_lanceStoreText", + "params": [ + 1, + "products", + "item-2", + "Lightweight running shoes with breathable mesh upper", + { "brand": "SpeedFit", "category": "footwear", "price": "89.99" } + ] +} +``` + +JSON-RPC (item 3): +```json +{ + "jsonrpc": "2.0", + "id": 6, + "method": "herodb_lanceStoreText", + "params": [ + 1, + "products", + "item-3", + "Insulated winter jacket with removable hood and multiple pockets", + { "brand": "WarmTech", "category": "outerwear", "price": "199.99" } + ] +} +``` + +JSON-RPC (item 4): +```json +{ + "jsonrpc": "2.0", + "id": 7, + "method": "herodb_lanceStoreText", + "params": [ + 1, + "products", + "item-4", + "Camping tent for 4 people with waterproof rainfly", + { "brand": "OutdoorPro", "category": "camping", "price": "249.99" } + ] +} +``` + +Expected response for each: `{"jsonrpc":"2.0","id":N,"result":true}` + +**Step 5: Search by text query** + +JSON-RPC: +```json +{ + "jsonrpc": "2.0", + "id": 8, + "method": "herodb_lanceSearchText", + "params": [ + 1, + "products", + "boots for hiking in wet conditions", + 3, + null, + ["brand", "category", "price"] + ] +} +``` + +Redis-like: +```bash +LANCE.SEARCH products K 3 QUERY "boots for hiking in wet conditions" RETURN 3 brand category price +``` + +Expected response: +```json +{ + "jsonrpc": "2.0", + "id": 8, + "result": { + "results": [ + { + "id": "item-1", + "score": 0.234, + "meta": { + 
"brand": "TrailMax", + "category": "footwear", + "price": "129.99" + } + }, + ... + ] + } +} +``` + +**Step 6: Search with metadata filter** + +JSON-RPC: +```json +{ + "jsonrpc": "2.0", + "id": 9, + "method": "herodb_lanceSearchText", + "params": [ + 1, + "products", + "comfortable shoes for running", + 5, + "category = 'footwear'", + null + ] +} +``` + +Redis-like: +```bash +LANCE.SEARCH products K 5 QUERY "comfortable shoes for running" FILTER "category = 'footwear'" +``` + +This returns only items where `category` equals `'footwear'`. + +**Step 7: List datasets** + +JSON-RPC: +```json +{ + "jsonrpc": "2.0", + "id": 10, + "method": "herodb_lanceList", + "params": [1] +} +``` + +Redis-like: +```bash +LANCE.LIST +``` + +**Step 8: Get dataset info** + +JSON-RPC: +```json +{ + "jsonrpc": "2.0", + "id": 11, + "method": "herodb_lanceInfo", + "params": [1, "products"] +} +``` + +Redis-like: +```bash +LANCE.INFO products +``` + +Returns dimension, row count, and other metadata. + +--- + +## Scenario 2: OpenAI API + +Use OpenAI's production embedding service for semantic search. + +### Setup + +**1. Set your API key:** + +```bash +export OPENAI_API_KEY="sk-your-actual-openai-key-here" +``` + +**2. 
Start HeroDB** (same as before): + +```bash +./target/release/herodb --dir ./data --admin-secret my-admin-secret --enable-rpc --rpc-port 8080 +``` + +### End-to-End Example with OpenAI + +**Step 1: Create a Lance database** + +JSON-RPC: +```json +{ + "jsonrpc": "2.0", + "id": 1, + "method": "herodb_createDatabase", + "params": [ + "Lance", + { "name": "openai-vectors", "storage_path": null, "max_size": null, "redis_version": null }, + null + ] +} +``` + +Expected: `{"jsonrpc":"2.0","id":1,"result":1}` (database ID = 1) + +**Step 2: Configure OpenAI embeddings** + +JSON-RPC: +```json +{ + "jsonrpc": "2.0", + "id": 2, + "method": "herodb_lanceSetEmbeddingConfig", + "params": [ + 1, + "documents", + { + "provider": "openai", + "model": "text-embedding-3-small", + "dim": 1536, + "endpoint": null, + "headers": {}, + "timeout_ms": 30000 + } + ] +} +``` + +Redis-like: +```bash +redis-cli -p 6379 +SELECT 1 +LANCE.EMBEDDING CONFIG SET documents PROVIDER openai MODEL text-embedding-3-small DIM 1536 TIMEOUTMS 30000 +``` + +Notes: +- `endpoint` is `null` (defaults to OpenAI API: https://api.openai.com/v1/embeddings) +- `headers` is empty (Authorization auto-added from OPENAI_API_KEY env var) +- `dim` is 1536 for text-embedding-3-small + +Expected: `{"jsonrpc":"2.0","id":2,"result":true}` + +**Step 3: Insert documents** + +JSON-RPC: +```json +{ + "jsonrpc": "2.0", + "id": 3, + "method": "herodb_lanceStoreText", + "params": [ + 1, + "documents", + "doc-1", + "The quick brown fox jumps over the lazy dog", + { "source": "example", "lang": "en", "topic": "animals" } + ] +} +``` + +```json +{ + "jsonrpc": "2.0", + "id": 4, + "method": "herodb_lanceStoreText", + "params": [ + 1, + "documents", + "doc-2", + "Machine learning models require large datasets for training and validation", + { "source": "tech", "lang": "en", "topic": "ai" } + ] +} +``` + +```json +{ + "jsonrpc": "2.0", + "id": 5, + "method": "herodb_lanceStoreText", + "params": [ + 1, + "documents", + "doc-3", + "Python is 
a popular programming language for data science and web development", + { "source": "tech", "lang": "en", "topic": "programming" } + ] +} +``` + +Redis-like: +```bash +LANCE.STORE documents ID doc-1 TEXT "The quick brown fox jumps over the lazy dog" META source example lang en topic animals +LANCE.STORE documents ID doc-2 TEXT "Machine learning models require large datasets for training and validation" META source tech lang en topic ai +LANCE.STORE documents ID doc-3 TEXT "Python is a popular programming language for data science and web development" META source tech lang en topic programming +``` + +Expected for each: `{"jsonrpc":"2.0","id":N,"result":true}` + +**Step 4: Semantic search** + +JSON-RPC: +```json +{ + "jsonrpc": "2.0", + "id": 6, + "method": "herodb_lanceSearchText", + "params": [ + 1, + "documents", + "artificial intelligence and neural networks", + 3, + null, + ["source", "topic"] + ] +} +``` + +Redis-like: +```bash +LANCE.SEARCH documents K 3 QUERY "artificial intelligence and neural networks" RETURN 2 source topic +``` + +Expected response (doc-2 should rank highest due to semantic similarity): +```json +{ + "jsonrpc": "2.0", + "id": 6, + "result": { + "results": [ + { + "id": "doc-2", + "score": 0.123, + "meta": { + "source": "tech", + "topic": "ai" + } + }, + { + "id": "doc-3", + "score": 0.456, + "meta": { + "source": "tech", + "topic": "programming" + } + }, + { + "id": "doc-1", + "score": 0.789, + "meta": { + "source": "example", + "topic": "animals" + } + } + ] + } +} +``` + +Note: Lower score = better match (L2 distance). 
+ +**Step 5: Search with filter** + +JSON-RPC: +```json +{ + "jsonrpc": "2.0", + "id": 7, + "method": "herodb_lanceSearchText", + "params": [ + 1, + "documents", + "programming and software", + 5, + "topic = 'programming'", + null + ] +} +``` + +Redis-like: +```bash +LANCE.SEARCH documents K 5 QUERY "programming and software" FILTER "topic = 'programming'" +``` + +This returns only documents where `topic` equals `'programming'`. + +--- + +## Scenario 2: OpenAI API + +Use OpenAI's production embedding service for high-quality semantic search. + +### Setup + +**1. Set your OpenAI API key:** + +```bash +export OPENAI_API_KEY="sk-your-actual-openai-key-here" +``` + +**2. Start HeroDB:** + +```bash +./target/release/herodb --dir ./data --admin-secret my-admin-secret --enable-rpc --rpc-port 8080 +``` + +### Complete Workflow + +**Step 1: Create database** + +JSON-RPC: +```json +{ + "jsonrpc": "2.0", + "id": 1, + "method": "herodb_createDatabase", + "params": [ + "Lance", + { "name": "openai-docs", "storage_path": null, "max_size": null, "redis_version": null }, + null + ] +} +``` + +**Step 2: Configure OpenAI embeddings** + +JSON-RPC: +```json +{ + "jsonrpc": "2.0", + "id": 2, + "method": "herodb_lanceSetEmbeddingConfig", + "params": [ + 1, + "articles", + { + "provider": "openai", + "model": "text-embedding-3-small", + "dim": 1536, + "endpoint": null, + "headers": {}, + "timeout_ms": 30000 + } + ] +} +``` + +Redis-like: +```bash +SELECT 1 +LANCE.EMBEDDING CONFIG SET articles PROVIDER openai MODEL text-embedding-3-small DIM 1536 +``` + +**Step 3: Insert articles** + +JSON-RPC: +```json +{ + "jsonrpc": "2.0", + "id": 3, + "method": "herodb_lanceStoreText", + "params": [ + 1, + "articles", + "article-1", + "Climate change is affecting global weather patterns and ecosystems", + { "category": "environment", "author": "Jane Smith", "year": "2024" } + ] +} +``` + +```json +{ + "jsonrpc": "2.0", + "id": 4, + "method": "herodb_lanceStoreText", + "params": [ + 1, + "articles", + 
"article-2", + "Quantum computing promises to revolutionize cryptography and drug discovery", + { "category": "technology", "author": "John Doe", "year": "2024" } + ] +} +``` + +```json +{ + "jsonrpc": "2.0", + "id": 5, + "method": "herodb_lanceStoreText", + "params": [ + 1, + "articles", + "article-3", + "Renewable energy sources like solar and wind are becoming more cost-effective", + { "category": "environment", "author": "Alice Johnson", "year": "2023" } + ] +} +``` + +**Step 4: Semantic search** + +JSON-RPC: +```json +{ + "jsonrpc": "2.0", + "id": 6, + "method": "herodb_lanceSearchText", + "params": [ + 1, + "articles", + "environmental sustainability and green energy", + 2, + null, + ["category", "author"] + ] +} +``` + +Redis-like: +```bash +LANCE.SEARCH articles K 2 QUERY "environmental sustainability and green energy" RETURN 2 category author +``` + +Expected: Returns article-1 and article-3 (both environment-related). + +**Step 5: Filtered search** + +JSON-RPC: +```json +{ + "jsonrpc": "2.0", + "id": 7, + "method": "herodb_lanceSearchText", + "params": [ + 1, + "articles", + "new technology innovations", + 5, + "category = 'technology'", + null + ] +} +``` + +--- + +## Scenario 3: Deterministic Test Embedder (No Network) + +For CI/offline development, use the built-in test embedder that requires no external service. + +### Configuration + +JSON-RPC: +```json +{ + "jsonrpc": "2.0", + "id": 1, + "method": "herodb_lanceSetEmbeddingConfig", + "params": [ + 1, + "testdata", + { + "provider": "test", + "model": "dev", + "dim": 64, + "endpoint": null, + "headers": {}, + "timeout_ms": null + } + ] +} +``` + +Redis-like: +```bash +SELECT 1 +LANCE.EMBEDDING CONFIG SET testdata PROVIDER test MODEL dev DIM 64 +``` + +### Usage + +Use `lanceStoreText` and `lanceSearchText` as in previous scenarios. 
The embeddings are: +- Deterministic (same text → same vector) +- Fast (no network) +- Not semantic (hash-based, not ML) + +Perfect for testing the vector storage/search mechanics without external dependencies. + +--- + +## Advanced: Custom Headers and Timeouts + +### Example: Local model with custom auth + +JSON-RPC: +```json +{ + "jsonrpc": "2.0", + "id": 1, + "method": "herodb_lanceSetEmbeddingConfig", + "params": [ + 1, + "secure-data", + { + "provider": "openai", + "model": "custom-model", + "dim": 512, + "endpoint": "http://192.168.1.100:9000/embeddings", + "headers": { + "Authorization": "Bearer my-local-token", + "X-Custom-Header": "value" + }, + "timeout_ms": 60000 + } + ] +} +``` + +### Example: OpenAI with explicit API key (not from env) + +JSON-RPC: +```json +{ + "jsonrpc": "2.0", + "id": 1, + "method": "herodb_lanceSetEmbeddingConfig", + "params": [ + 1, + "dataset", + { + "provider": "openai", + "model": "text-embedding-3-small", + "dim": 1536, + "endpoint": null, + "headers": { + "Authorization": "Bearer sk-your-key-here" + }, + "timeout_ms": 30000 + } + ] +} +``` + +--- + +## Troubleshooting + +### Error: "Embedding config not set for dataset" + +**Cause:** You tried to use `lanceStoreText` or `lanceSearchText` without configuring an embedder. + +**Solution:** Run `lanceSetEmbeddingConfig` first. + +### Error: "Embedding dimension mismatch: expected X, got Y" + +**Cause:** The embedding service returned vectors of a different size than configured. + +**Solution:** +- For OpenAI text-embedding-3-small, use `dim: 1536` +- For your local mock (from this tutorial), use `dim: 768` +- Check your embedding service's actual output dimension + +### Error: "Missing API key in env 'OPENAI_API_KEY'" + +**Cause:** Using OpenAI provider without setting the API key. 
+ +**Solution:** +- Set `export OPENAI_API_KEY="sk-..."` before starting HeroDB, OR +- Pass the key explicitly in headers: `"Authorization": "Bearer sk-..."` + +### Error: "HTTP request failed" or "Embeddings API error 404" + +**Cause:** Cannot reach the embedding endpoint. + +**Solution:** +- Verify your local server is running: `curl http://127.0.0.1:8081/v1/embeddings` +- Check the endpoint URL in your config +- Ensure firewall allows the connection + +### Error: "ERR DB backend is not Lance" + +**Cause:** Trying to use LANCE.* commands on a non-Lance database. + +**Solution:** Create the database with backend "Lance" (see Step 1). + +### Error: "write permission denied" + +**Cause:** Database is private and you haven't authenticated. + +**Solution:** Use `SELECT <db_id> KEY <access_key>` or make the database public via RPC. + +--- + +## Complete Example Script (Bash + curl) + +Save as `test_embeddings.sh`: + +```bash +#!/bin/bash + +RPC_URL="http://localhost:8080" + +# 1. Create Lance database +curl -X POST $RPC_URL -H "Content-Type: application/json" -d '{ + "jsonrpc": "2.0", + "id": 1, + "method": "herodb_createDatabase", + "params": ["Lance", {"name": "test-vectors", "storage_path": null, "max_size": null, "redis_version": null}, null] +}' + +echo -e "\n" + +# 2. Configure local embedder +curl -X POST $RPC_URL -H "Content-Type: application/json" -d '{ + "jsonrpc": "2.0", + "id": 2, + "method": "herodb_lanceSetEmbeddingConfig", + "params": [1, "products", { + "provider": "openai", + "model": "mock", + "dim": 768, + "endpoint": "http://127.0.0.1:8081/v1/embeddings", + "headers": {"Authorization": "Bearer dummy"}, + "timeout_ms": 30000 + }] +}' + +echo -e "\n" + +# 3. Insert data +curl -X POST $RPC_URL -H "Content-Type: application/json" -d '{ + "jsonrpc": "2.0", + "id": 3, + "method": "herodb_lanceStoreText", + "params": [1, "products", "item-1", "Hiking boots", {"brand": "TrailMax"}] +}' + +echo -e "\n" + +# 4. 
Search +curl -X POST $RPC_URL -H "Content-Type: application/json" -d '{ + "jsonrpc": "2.0", + "id": 4, + "method": "herodb_lanceSearchText", + "params": [1, "products", "outdoor footwear", 5, null, null] +}' + +echo -e "\n" +``` + +Run: +```bash +chmod +x test_embeddings.sh +./test_embeddings.sh +``` + +--- + +## Summary + +| Provider | Use Case | Endpoint | API Key | +|----------|----------|----------|---------| +| `openai` | Production semantic search | Default (OpenAI) or custom URL | OPENAI_API_KEY env or headers | +| `openai` | Local self-hosted gateway | http://127.0.0.1:8081/... | Optional (depends on your service) | +| `test` | CI/offline development | N/A (local hash) | None | +| `image_test` | Image testing | N/A (local hash) | None | + +**Notes:** +- The `provider` field is always `"openai"` for OpenAI-compatible services (whether cloud or local). This is because it uses the OpenAI-compatible API shape. +- Use `endpoint` to point to your local service +- Use `headers` for custom authentication +- `dim` must match your embedding service's output dimension +- Once configured, `lanceStoreText` and `lanceSearchText` handle embedding automatically \ No newline at end of file diff --git a/mock_embedder.py b/mock_embedder.py new file mode 100644 index 0000000..1dbfca9 --- /dev/null +++ b/mock_embedder.py @@ -0,0 +1,34 @@ +from flask import Flask, request, jsonify +import numpy as np + +app = Flask(__name__) + +@app.route('/v1/embeddings', methods=['POST']) +def embeddings(): + data = request.json + inputs = data.get('input', []) + if isinstance(inputs, str): + inputs = [inputs] + + # Generate deterministic 768-dim embeddings (hash-based) + embeddings = [] + for text in inputs: + # Simple hash to vector + vec = np.zeros(768) + for i, char in enumerate(text[:768]): + vec[i % 768] += ord(char) / 255.0 + # Normalize + norm = np.linalg.norm(vec) + if norm > 0: + vec = vec / norm + embeddings.append(vec.tolist()) + + return jsonify({ + "data": [{"embedding": emb} for 
emb in embeddings], + "model": data.get('model', 'mock'), + "usage": {"total_tokens": sum(len(t) for t in inputs)} + }) + +if __name__ == '__main__': + app.run(host='127.0.0.1', port=8081) + diff --git a/src/cmd.rs b/src/cmd.rs index 88884da..a9e9e1e 100644 --- a/src/cmd.rs +++ b/src/cmd.rs @@ -1,4 +1,4 @@ -use crate::{error::DBError, protocol::Protocol, server::Server, embedding::{EmbeddingConfig, EmbeddingProvider}}; +use crate::{error::DBError, protocol::Protocol, server::Server, embedding::EmbeddingConfig}; use base64::{engine::general_purpose, Engine as _}; use tokio::time::{timeout, Duration}; use futures::future::select_all; @@ -170,9 +170,7 @@ pub enum Cmd { // Embedding configuration per dataset LanceEmbeddingConfigSet { name: String, - provider: String, - model: String, - params: Vec<(String, String)>, + config: EmbeddingConfig, }, LanceEmbeddingConfigGet { name: String, @@ -1089,20 +1087,25 @@ impl Cmd { Cmd::LanceCreateIndex { name, index_type, params } } "lance.embedding" => { - // LANCE.EMBEDDING CONFIG SET name PROVIDER p MODEL m [PARAM k v ...] + // LANCE.EMBEDDING CONFIG SET name PROVIDER p MODEL m DIM d [ENDPOINT url] [HEADER k v]... [TIMEOUTMS t] // LANCE.EMBEDDING CONFIG GET name if cmd.len() < 3 || cmd[1].to_uppercase() != "CONFIG" { return Err(DBError("ERR LANCE.EMBEDDING requires CONFIG subcommand".to_string())); } if cmd.len() >= 4 && cmd[2].to_uppercase() == "SET" { if cmd.len() < 8 { - return Err(DBError("ERR LANCE.EMBEDDING CONFIG SET requires: SET name PROVIDER p MODEL m [PARAM k v ...]".to_string())); + return Err(DBError("ERR LANCE.EMBEDDING CONFIG SET requires: SET name PROVIDER p MODEL m DIM d [ENDPOINT url] [HEADER k v]... 
[TIMEOUTMS t]".to_string())); } let name = cmd[3].clone(); let mut i = 4; + let mut provider: Option = None; let mut model: Option = None; - let mut params: Vec<(String, String)> = Vec::new(); + let mut dim: Option = None; + let mut endpoint: Option = None; + let mut headers: std::collections::HashMap = std::collections::HashMap::new(); + let mut timeout_ms: Option = None; + while i < cmd.len() { match cmd[i].to_uppercase().as_str() { "PROVIDER" => { @@ -1119,22 +1122,62 @@ impl Cmd { model = Some(cmd[i + 1].clone()); i += 2; } - "PARAM" => { - i += 1; - while i + 1 < cmd.len() { - params.push((cmd[i].clone(), cmd[i + 1].clone())); - i += 2; + "DIM" => { + if i + 1 >= cmd.len() { + return Err(DBError("ERR DIM requires a value".to_string())); } + let d: usize = cmd[i + 1].parse().map_err(|_| DBError("ERR DIM must be an integer".to_string()))?; + dim = Some(d); + i += 2; + } + "ENDPOINT" => { + if i + 1 >= cmd.len() { + return Err(DBError("ERR ENDPOINT requires a value".to_string())); + } + endpoint = Some(cmd[i + 1].clone()); + i += 2; + } + "HEADER" => { + if i + 2 >= cmd.len() { + return Err(DBError("ERR HEADER requires key and value".to_string())); + } + headers.insert(cmd[i + 1].clone(), cmd[i + 2].clone()); + i += 3; + } + "TIMEOUTMS" => { + if i + 1 >= cmd.len() { + return Err(DBError("ERR TIMEOUTMS requires a value".to_string())); + } + let t: u64 = cmd[i + 1].parse().map_err(|_| DBError("ERR TIMEOUTMS must be an integer".to_string()))?; + timeout_ms = Some(t); + i += 2; } _ => { - // Unknown token; break to avoid infinite loop i += 1; } } } - let provider = provider.ok_or_else(|| DBError("ERR missing PROVIDER".to_string()))?; + + let provider_str = provider.ok_or_else(|| DBError("ERR missing PROVIDER".to_string()))?; + let provider_enum = match provider_str.to_lowercase().as_str() { + "openai" => crate::embedding::EmbeddingProvider::openai, + "test" => crate::embedding::EmbeddingProvider::test, + "image_test" | "imagetest" | "image-test" => 
crate::embedding::EmbeddingProvider::image_test, + other => return Err(DBError(format!("ERR unsupported provider '{}'", other))), + }; let model = model.ok_or_else(|| DBError("ERR missing MODEL".to_string()))?; - Cmd::LanceEmbeddingConfigSet { name, provider, model, params } + let dim = dim.ok_or_else(|| DBError("ERR missing DIM".to_string()))?; + + let config = EmbeddingConfig { + provider: provider_enum, + model, + dim, + endpoint, + headers, + timeout_ms, + }; + + Cmd::LanceEmbeddingConfigSet { name, config } } else if cmd.len() == 4 && cmd[2].to_uppercase() == "GET" { let name = cmd[3].clone(); Cmd::LanceEmbeddingConfigGet { name } @@ -1437,25 +1480,14 @@ impl Cmd { Err(e) => Ok(Protocol::err(&e.0)), } } - Cmd::LanceEmbeddingConfigSet { name, provider, model, params } => { + Cmd::LanceEmbeddingConfigSet { name, config } => { if !server.has_write_permission() { return Ok(Protocol::err("ERR write permission denied")); } - // Map provider string to enum - let p_lc = provider.to_lowercase(); - let prov = match p_lc.as_str() { - "test-hash" | "testhash" => EmbeddingProvider::TestHash, - "testimagehash" | "image-test-hash" | "imagetesthash" => EmbeddingProvider::ImageTestHash, - "fastembed" | "lancefastembed" => EmbeddingProvider::LanceFastEmbed, - "openai" | "lanceopenai" => EmbeddingProvider::LanceOpenAI, - other => EmbeddingProvider::LanceOther(other.to_string()), - }; - let cfg = EmbeddingConfig { - provider: prov, - model, - params: params.into_iter().collect(), - }; - match server.set_dataset_embedding_config(&name, &cfg) { + if config.dim == 0 { + return Ok(Protocol::err("ERR embedding DIM must be > 0")); + } + match server.set_dataset_embedding_config(&name, &config) { Ok(()) => Ok(Protocol::SimpleString("OK".to_string())), Err(e) => Ok(Protocol::err(&e.0)), } @@ -1466,16 +1498,20 @@ impl Cmd { let mut arr = Vec::new(); arr.push(Protocol::BulkString("provider".to_string())); arr.push(Protocol::BulkString(match cfg.provider { - EmbeddingProvider::TestHash => 
"test-hash".to_string(), - EmbeddingProvider::ImageTestHash => "testimagehash".to_string(), - EmbeddingProvider::LanceFastEmbed => "lancefastembed".to_string(), - EmbeddingProvider::LanceOpenAI => "lanceopenai".to_string(), - EmbeddingProvider::LanceOther(ref s) => s.clone(), + crate::embedding::EmbeddingProvider::openai => "openai".to_string(), + crate::embedding::EmbeddingProvider::test => "test".to_string(), + crate::embedding::EmbeddingProvider::image_test => "image_test".to_string(), })); arr.push(Protocol::BulkString("model".to_string())); arr.push(Protocol::BulkString(cfg.model.clone())); - arr.push(Protocol::BulkString("params".to_string())); - arr.push(Protocol::BulkString(serde_json::to_string(&cfg.params).unwrap_or_else(|_| "{}".to_string()))); + arr.push(Protocol::BulkString("dim".to_string())); + arr.push(Protocol::BulkString(cfg.dim.to_string())); + arr.push(Protocol::BulkString("endpoint".to_string())); + arr.push(Protocol::BulkString(cfg.endpoint.clone().unwrap_or_default())); + arr.push(Protocol::BulkString("timeout_ms".to_string())); + arr.push(Protocol::BulkString(cfg.timeout_ms.map(|v| v.to_string()).unwrap_or_default())); + arr.push(Protocol::BulkString("headers".to_string())); + arr.push(Protocol::BulkString(serde_json::to_string(&cfg.headers).unwrap_or_else(|_| "{}".to_string()))); Ok(Protocol::Array(arr)) } Err(e) => Ok(Protocol::err(&e.0)), diff --git a/src/embedding.rs b/src/embedding.rs index 79690d4..06c7373 100644 --- a/src/embedding.rs +++ b/src/embedding.rs @@ -1,4 +1,4 @@ -// Embedding abstraction and minimal providers. +// Embedding abstraction with a single external provider (OpenAI-compatible) and local test providers. 
use std::collections::HashMap; use std::sync::Arc; @@ -7,42 +7,41 @@ use serde::{Deserialize, Serialize}; use crate::error::DBError; -// Networking for OpenAI/Azure +// Networking for OpenAI-compatible endpoints use std::time::Duration; use ureq::{Agent, AgentBuilder}; use serde_json::json; -/// Provider identifiers. Extend as needed to mirror LanceDB-supported providers. +/// Provider identifiers (minimal set). #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] #[serde(rename_all = "snake_case")] pub enum EmbeddingProvider { - // Deterministic, local-only embedder for CI and offline development (text). - TestHash, - // Deterministic, local-only embedder for CI and offline development (image). - ImageTestHash, - // Placeholders for LanceDB-supported providers; implementers can add concrete backends later. - LanceFastEmbed, - LanceOpenAI, - LanceOther(String), + /// External HTTP provider compatible with OpenAI's embeddings API. + openai, + /// Deterministic, local-only embedder for CI and offline development (text). + test, + /// Deterministic, local-only embedder for CI and offline development (image). + image_test, } /// Serializable embedding configuration. -/// params: arbitrary key-value map for provider-specific knobs (e.g., "dim", "api_key_env", etc.) +/// - provider: "openai" | "test" | "image_test" +/// - model: provider/model id (e.g., "text-embedding-3-small"), may be ignored by local gateways +/// - dim: required output dimension (used to create Lance datasets and validate outputs) +/// - endpoint: optional HTTP endpoint (defaults to OpenAI API when provider == openai) +/// - headers: optional HTTP headers (e.g., Authorization). If empty and OPENAI_API_KEY is present, Authorization will be inferred. 
+/// - timeout_ms: optional HTTP timeout in milliseconds (for both read and write)
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct EmbeddingConfig {
     pub provider: EmbeddingProvider,
     pub model: String,
+    pub dim: usize,
     #[serde(default)]
-    pub params: HashMap<String, String>,
-}
-
-impl EmbeddingConfig {
-    pub fn get_param_usize(&self, key: &str) -> Option<usize> {
-        self.params.get(key).and_then(|v| v.parse::<usize>().ok())
-    }
-    pub fn get_param_string(&self, key: &str) -> Option<String> {
-        self.params.get(key).cloned()
-    }
+    pub endpoint: Option<String>,
+    #[serde(default)]
+    pub headers: HashMap<String, String>,
+    #[serde(default)]
+    pub timeout_ms: Option<u64>,
 }
 
 /// A provider-agnostic text embedding interface.
@@ -59,6 +58,20 @@ pub trait Embedder: Send + Sync {
     }
 }
 
+/// Image embedding interface (separate from text to keep modality-specific inputs).
+pub trait ImageEmbedder: Send + Sync {
+    /// Human-readable provider/model name
+    fn name(&self) -> String;
+    /// Embedding dimension
+    fn dim(&self) -> usize;
+    /// Embed a single image (raw bytes)
+    fn embed_image(&self, bytes: &[u8]) -> Result<Vec<f32>, DBError>;
+    /// Embed many images; default maps embed_image() over inputs
+    fn embed_many_images(&self, images: &[Vec<u8>]) -> Result<Vec<Vec<f32>>, DBError> {
+        images.iter().map(|b| self.embed_image(b)).collect()
+    }
+}
+
 //// ----------------------------- TEXT: deterministic test embedder -----------------------------
 
 /// Deterministic, no-deps, no-network embedder for CI and offline dev.
@@ -88,7 +101,7 @@ impl TestHashEmbedder {
 
 impl Embedder for TestHashEmbedder {
     fn name(&self) -> String {
-        format!("test-hash:{}", self.model_name)
+        format!("test:{}", self.model_name)
     }
 
     fn dim(&self) -> usize {
@@ -117,21 +130,7 @@ impl Embedder for TestHashEmbedder {
     }
 }
 
-//// ----------------------------- IMAGE: trait + deterministic test embedder -----------------------------
-
-/// Image embedding interface (separate from text to keep modality-specific inputs).
-pub trait ImageEmbedder: Send + Sync {
-    /// Human-readable provider/model name
-    fn name(&self) -> String;
-    /// Embedding dimension
-    fn dim(&self) -> usize;
-    /// Embed a single image (raw bytes)
-    fn embed_image(&self, bytes: &[u8]) -> Result<Vec<f32>, DBError>;
-    /// Embed many images; default maps embed_image() over inputs
-    fn embed_many_images(&self, images: &[Vec<u8>]) -> Result<Vec<Vec<f32>>, DBError> {
-        images.iter().map(|b| self.embed_image(b)).collect()
-    }
-}
+//// ----------------------------- IMAGE: deterministic test embedder -----------------------------
 
 /// Deterministic image embedder that folds bytes into buckets, applies tanh-like nonlinearity,
 /// and L2-normalizes. Suitable for CI and offline development.
@@ -159,7 +158,7 @@ impl TestImageHashEmbedder {
 
 impl ImageEmbedder for TestImageHashEmbedder {
     fn name(&self) -> String {
-        format!("test-image-hash:{}", self.model_name)
+        format!("image_test:{}", self.model_name)
     }
 
     fn dim(&self) -> usize {
@@ -188,80 +187,45 @@ impl ImageEmbedder for TestImageHashEmbedder {
     }
 }
 
-//// OpenAI embedder (supports OpenAI and Azure OpenAI via REST)
+//// ----------------------------- OpenAI-compatible HTTP embedder -----------------------------
+
 struct OpenAIEmbedder {
     model: String,
     dim: usize,
     agent: Agent,
     endpoint: String,
     headers: Vec<(String, String)>,
-    use_azure: bool,
 }
 
 impl OpenAIEmbedder {
     fn new_from_config(cfg: &EmbeddingConfig) -> Result<Self, DBError> {
-        // Whether to use Azure OpenAI
-        let use_azure = cfg
-            .get_param_string("use_azure")
-            .map(|s| s.eq_ignore_ascii_case("true"))
-            .unwrap_or(false);
-
-        // Resolve API key (OPENAI_API_KEY or AZURE_OPENAI_API_KEY by default)
-        let api_key_env = cfg
-            .get_param_string("api_key_env")
-            .unwrap_or_else(|| {
-                if use_azure {
-                    "AZURE_OPENAI_API_KEY".to_string()
-                } else {
-                    "OPENAI_API_KEY".to_string()
-                }
-            });
-        let api_key = std::env::var(&api_key_env)
-            .map_err(|_| DBError(format!("Missing API key in env '{}'", api_key_env)))?;
-        // Resolve endpoint
-        // - Standard OpenAI:
https://api.openai.com/v1/embeddings (default) or params["base_url"] - // - Azure OpenAI: {azure_endpoint}/openai/deployments/{deployment}/embeddings?api-version=... - let endpoint = if use_azure { - let base = cfg - .get_param_string("azure_endpoint") - .ok_or_else(|| DBError("Missing 'azure_endpoint' for Azure OpenAI".into()))?; - let deployment = cfg - .get_param_string("azure_deployment") - .unwrap_or_else(|| cfg.model.clone()); - let api_version = cfg - .get_param_string("azure_api_version") - .unwrap_or_else(|| "2023-05-15".to_string()); - format!( - "{}/openai/deployments/{}/embeddings?api-version={}", - base.trim_end_matches('/'), - deployment, - api_version - ) - } else { - cfg.get_param_string("base_url") - .unwrap_or_else(|| "https://api.openai.com/v1/embeddings".to_string()) - }; + let endpoint = cfg.endpoint.clone().unwrap_or_else(|| { + "https://api.openai.com/v1/embeddings".to_string() + }); - // Determine expected dimension (default 1536 for text-embedding-3-small; callers should override if needed) - let dim = cfg - .get_param_usize("dim") - .or_else(|| cfg.get_param_usize("dimensions")) - .unwrap_or(1536); + // Determine expected dimension (required by config) + let dim = cfg.dim; // Build an HTTP agent with timeouts (blocking; no tokio runtime involved) + let to_ms = cfg.timeout_ms.unwrap_or(30_000); let agent = AgentBuilder::new() - .timeout_read(Duration::from_secs(30)) - .timeout_write(Duration::from_secs(30)) + .timeout_read(Duration::from_millis(to_ms)) + .timeout_write(Duration::from_millis(to_ms)) .build(); - // Headers - let mut headers: Vec<(String, String)> = Vec::new(); - headers.push(("Content-Type".to_string(), "application/json".to_string())); - if use_azure { - headers.push(("api-key".to_string(), api_key)); - } else { - headers.push(("Authorization".to_string(), format!("Bearer {}", api_key))); + // Headers: start from cfg.headers, and add Authorization from env if absent and available + let mut headers: Vec<(String, String)> = + 
            cfg.headers.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
+
+        if !headers.iter().any(|(k, _)| k.eq_ignore_ascii_case("content-type")) {
+            headers.push(("Content-Type".to_string(), "application/json".to_string()));
+        }
+
+        if !headers.iter().any(|(k, _)| k.eq_ignore_ascii_case("authorization")) {
+            if let Ok(key) = std::env::var("OPENAI_API_KEY") {
+                headers.push(("Authorization".to_string(), format!("Bearer {}", key)));
+            }
         }
 
         Ok(Self {
@@ -270,19 +234,12 @@ impl OpenAIEmbedder {
             agent,
             endpoint,
             headers,
-            use_azure,
         })
     }
 
     fn request_many(&self, inputs: &[String]) -> Result<Vec<Vec<f32>>, DBError> {
-        // Compose request body:
-        // - Standard OpenAI: { "model": ..., "input": [...], "dimensions": dim? }
-        // - Azure: { "input": [...], "dimensions": dim? } (model from deployment)
-        let mut body = if self.use_azure {
-            json!({ "input": inputs })
-        } else {
-            json!({ "model": self.model, "input": inputs })
-        };
+        // Compose request body (OpenAI-compatible)
+        let mut body = json!({ "model": self.model, "input": inputs });
         if self.dim > 0 {
             body.as_object_mut()
                 .unwrap()
@@ -331,7 +288,7 @@ impl OpenAIEmbedder {
            }
            if self.dim > 0 && v.len() != self.dim {
                return Err(DBError(format!(
-                    "Embedding dimension mismatch: expected {}, got {}. Configure 'dim' or 'dimensions' to match output.",
+                    "Embedding dimension mismatch: expected {}, got {}. Configure 'dim' to match output.",
                    self.dim, v.len()
                )));
            }
@@ -343,11 +300,7 @@ impl OpenAIEmbedder {
 
 impl Embedder for OpenAIEmbedder {
     fn name(&self) -> String {
-        if self.use_azure {
-            format!("azure-openai:{}", self.model)
-        } else {
-            format!("openai:{}", self.model)
-        }
+        format!("openai:{}", self.model)
     }
 
     fn dim(&self) -> usize {
@@ -368,38 +321,33 @@ impl Embedder for OpenAIEmbedder {
 }
 
 /// Create an embedder instance from a config.
-/// - TestHash: uses params["dim"] or defaults to 64
-/// - LanceOpenAI: uses OpenAI (or Azure OpenAI) embeddings REST API
-/// - Other Lance providers can be added similarly
+/// - openai: uses OpenAI-compatible embeddings REST API (endpoint override supported)
+/// - test: deterministic local text embedder (no network)
+/// - image_test: not valid for text (use create_image_embedder)
 pub fn create_embedder(config: &EmbeddingConfig) -> Result<Arc<dyn Embedder>, DBError> {
     match &config.provider {
-        EmbeddingProvider::TestHash => {
-            let dim = config.get_param_usize("dim").unwrap_or(64);
-            Ok(Arc::new(TestHashEmbedder::new(dim, config.model.clone())))
-        }
-        EmbeddingProvider::LanceOpenAI => {
+        EmbeddingProvider::openai => {
            let inner = OpenAIEmbedder::new_from_config(config)?;
            Ok(Arc::new(inner))
        }
-        EmbeddingProvider::ImageTestHash => {
+        EmbeddingProvider::test => {
+            Ok(Arc::new(TestHashEmbedder::new(config.dim, config.model.clone())))
+        }
+        EmbeddingProvider::image_test => {
            Err(DBError("Use create_image_embedder() for image providers".into()))
        }
-        EmbeddingProvider::LanceFastEmbed => Err(DBError("LanceFastEmbed provider not yet implemented in Rust embedding layer; configure 'test-hash' or use 'openai'".into())),
-        EmbeddingProvider::LanceOther(p) => Err(DBError(format!("Lance provider '{}' not implemented; configure 'openai' or 'test-hash'", p))),
    }
 }
 
 /// Create an image embedder instance from a config.
+/// - image_test: deterministic local image embedder
 pub fn create_image_embedder(config: &EmbeddingConfig) -> Result<Arc<dyn ImageEmbedder>, DBError> {
     match &config.provider {
-        EmbeddingProvider::ImageTestHash => {
-            let dim = config.get_param_usize("dim").unwrap_or(512);
-            Ok(Arc::new(TestImageHashEmbedder::new(dim, config.model.clone())))
+        EmbeddingProvider::image_test => {
+            Ok(Arc::new(TestImageHashEmbedder::new(config.dim, config.model.clone())))
        }
-        EmbeddingProvider::TestHash | EmbeddingProvider::LanceOpenAI => {
-            Err(DBError("Configured text provider; dataset expects image provider (e.g., 'testimagehash')".into()))
+        EmbeddingProvider::test | EmbeddingProvider::openai => {
+            Err(DBError("Configured text provider; dataset expects image provider (e.g., 'image_test')".into()))
        }
-        EmbeddingProvider::LanceFastEmbed => Err(DBError("Image provider 'lancefastembed' not yet implemented".into())),
-        EmbeddingProvider::LanceOther(p) => Err(DBError(format!("Image provider '{}' not implemented; use 'testimagehash' for now", p))),
    }
 }
\ No newline at end of file
diff --git a/src/rpc.rs b/src/rpc.rs
index 609014f..90d6d06 100644
--- a/src/rpc.rs
+++ b/src/rpc.rs
@@ -9,7 +9,7 @@ use sha2::{Digest, Sha256};
 use crate::server::Server;
 use crate::options::DBOption;
 use crate::admin_meta;
-use crate::embedding::{EmbeddingConfig, EmbeddingProvider};
+use crate::embedding::EmbeddingConfig;
 use base64::{engine::general_purpose, Engine as _};
 
 /// Database backend types
@@ -248,9 +248,7 @@ pub trait Rpc {
         &self,
         db_id: u64,
         name: String,
-        provider: String,
-        model: String,
-        params: Option<HashMap<String, String>>,
+        config: EmbeddingConfig,
     ) -> RpcResult<bool>;
 
     /// Get per-dataset embedding configuration
@@ -1008,9 +1006,7 @@ impl RpcServer for RpcServerImpl {
         &self,
         db_id: u64,
         name: String,
-        provider: String,
-        model: String,
-        params: Option<HashMap<String, String>>,
+        config: EmbeddingConfig,
     ) -> RpcResult<bool> {
         let server = self.get_or_create_server(db_id).await?;
         if db_id == 0 {
@@ -1022,19 +1018,17 @@ impl RpcServer for RpcServerImpl {
         if
!server.has_write_permission() { return Err(jsonrpsee::types::ErrorObjectOwned::owned(-32000, "write permission denied", None::<()>)); } - let prov = match provider.to_lowercase().as_str() { - "test-hash" | "testhash" => EmbeddingProvider::TestHash, - "testimagehash" | "image-test-hash" | "imagetesthash" => EmbeddingProvider::ImageTestHash, - "fastembed" | "lancefastembed" => EmbeddingProvider::LanceFastEmbed, - "openai" | "lanceopenai" => EmbeddingProvider::LanceOpenAI, - other => EmbeddingProvider::LanceOther(other.to_string()), - }; - let cfg = EmbeddingConfig { - provider: prov, - model, - params: params.unwrap_or_default(), - }; - server.set_dataset_embedding_config(&name, &cfg) + // Validate provider and dimension (only a minimal set is allowed for now) + match config.provider { + crate::embedding::EmbeddingProvider::openai + | crate::embedding::EmbeddingProvider::test + | crate::embedding::EmbeddingProvider::image_test => {} + } + if config.dim == 0 { + return Err(jsonrpsee::types::ErrorObjectOwned::owned(-32000, "Invalid embedding config: dim must be > 0", None::<()>)); + } + + server.set_dataset_embedding_config(&name, &config) .map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?; Ok(true) } @@ -1056,17 +1050,7 @@ impl RpcServer for RpcServerImpl { } let cfg = server.get_dataset_embedding_config(&name) .map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?; - Ok(serde_json::json!({ - "provider": match cfg.provider { - EmbeddingProvider::TestHash => "test-hash", - EmbeddingProvider::ImageTestHash => "testimagehash", - EmbeddingProvider::LanceFastEmbed => "lancefastembed", - EmbeddingProvider::LanceOpenAI => "lanceopenai", - EmbeddingProvider::LanceOther(ref s) => s, - }, - "model": cfg.model, - "params": cfg.params - })) + Ok(serde_json::to_value(&cfg).unwrap_or(serde_json::json!({}))) } async fn lance_store_text(