Fix a few bugs related to vector embedding; add step-by-step end-to-end documentation showcasing local and external embedders; add an example mock embedder Python script

This commit is contained in:
Maxime Van Hees
2025-10-16 15:30:45 +02:00
parent a8720c06db
commit df780e20a2
5 changed files with 1188 additions and 198 deletions


@@ -0,0 +1,988 @@
# HeroDB Embedding Models: Complete Tutorial
This tutorial demonstrates how to use embedding models with HeroDB for vector search, covering both local self-hosted models and OpenAI's API.
## Table of Contents
- [Prerequisites](#prerequisites)
- [Scenario 1: Local Embedding Model](#scenario-1-local-embedding-model-testing)
- [Scenario 2: OpenAI API](#scenario-2-openai-api)
- [Scenario 3: Deterministic Test Embedder](#scenario-3-deterministic-test-embedder-no-network)
- [Troubleshooting](#troubleshooting)
---
## Prerequisites
### Start HeroDB Server
Build and start HeroDB with RPC enabled:
```bash
cargo build --release
./target/release/herodb --dir ./data --admin-secret my-admin-secret --enable-rpc --rpc-port 8080
```
This starts:
- Redis-compatible server on port 6379
- JSON-RPC server on port 8080
### Client Tools
For Redis-like commands:
```bash
redis-cli -p 6379
```
For JSON-RPC calls, use `curl`:
```bash
curl -X POST http://localhost:8080 \
-H "Content-Type: application/json" \
-d '{"jsonrpc":"2.0","id":1,"method":"herodb_METHOD","params":[...]}'
```
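The same calls can also be scripted from Python with only the standard library. A minimal sketch, assuming the RPC server from this tutorial is listening on `http://localhost:8080`:

```python
import json
import urllib.request

def build_request(method, params, req_id=1):
    """Construct a JSON-RPC 2.0 request body."""
    return {"jsonrpc": "2.0", "id": req_id, "method": method, "params": params}

def rpc(method, params, url="http://localhost:8080", req_id=1):
    """POST a JSON-RPC request to HeroDB and return the decoded response."""
    payload = json.dumps(build_request(method, params, req_id)).encode("utf-8")
    req = urllib.request.Request(
        url, data=payload, headers={"Content-Type": "application/json"}
    )
    with urllib.request.urlopen(req) as resp:
        return json.loads(resp.read().decode("utf-8"))

# Example (requires a running server):
# rpc("herodb_lanceList", [1])
```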
---
## Scenario 1: Local Embedding Model (Testing)
Run your own embedding service locally for development, testing, or privacy.
### Option A: Python Mock Server (Simplest)
This creates a minimal OpenAI-compatible embedding server for testing.
**1. Create `mock_embedder.py`:**
```python
from flask import Flask, request, jsonify
import numpy as np

app = Flask(__name__)

@app.route('/v1/embeddings', methods=['POST'])
def embeddings():
    """OpenAI-compatible embeddings endpoint"""
    data = request.json
    inputs = data.get('input', [])
    # Handle both single string and array
    if isinstance(inputs, str):
        inputs = [inputs]
    # Generate deterministic 768-dim embeddings (hash-based)
    embeddings = []
    for text in inputs:
        # Simple hash to vector (deterministic)
        vec = np.zeros(768)
        for i, char in enumerate(text[:768]):
            vec[i % 768] += ord(char) / 255.0
        # L2 normalize
        norm = np.linalg.norm(vec)
        if norm > 0:
            vec = vec / norm
        embeddings.append(vec.tolist())
    return jsonify({
        "data": [{"embedding": emb, "index": i} for i, emb in enumerate(embeddings)],
        "model": data.get('model', 'mock-local'),
        "usage": {"total_tokens": sum(len(t) for t in inputs)}
    })

if __name__ == '__main__':
    print("Starting mock embedding server on http://127.0.0.1:8081")
    app.run(host='127.0.0.1', port=8081, debug=False)
```
**2. Install dependencies and run:**
```bash
pip install flask numpy
python mock_embedder.py
```
Output: `Starting mock embedding server on http://127.0.0.1:8081`
**3. Test the server (optional):**
```bash
curl -X POST http://127.0.0.1:8081/v1/embeddings \
-H "Content-Type: application/json" \
-d '{"input":["hello world"],"model":"test"}'
```
You should see a JSON response with a 768-dimensional embedding.
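For reference, the mock's embedding scheme can be reproduced in pure Python (no Flask or NumPy). This sketch mirrors the hashing logic above and shows that the output is deterministic and L2-normalized:

```python
import math

def mock_embed(text, dim=768):
    """Pure-Python version of the mock server's hash embedding:
    fold character codes into buckets, then L2-normalize."""
    vec = [0.0] * dim
    for i, ch in enumerate(text[:dim]):
        vec[i % dim] += ord(ch) / 255.0
    norm = math.sqrt(sum(x * x for x in vec))
    if norm > 0:
        vec = [x / norm for x in vec]
    return vec

# Same text always yields the same unit-length vector
v = mock_embed("hello world")
assert v == mock_embed("hello world")
```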
### End-to-End Example with Local Model
**Step 1: Create a Lance database**
JSON-RPC:
```json
{
"jsonrpc": "2.0",
"id": 1,
"method": "herodb_createDatabase",
"params": [
"Lance",
{ "name": "local-vectors", "storage_path": null, "max_size": null, "redis_version": null },
null
]
}
```
Expected response:
```json
{"jsonrpc":"2.0","id":1,"result":1}
```
The database ID is `1`.
**Step 2: Configure embedding for the dataset**
JSON-RPC:
```json
{
"jsonrpc": "2.0",
"id": 2,
"method": "herodb_lanceSetEmbeddingConfig",
"params": [
1,
"products",
{
"provider": "openai",
"model": "mock-local",
"dim": 768,
"endpoint": "http://127.0.0.1:8081/v1/embeddings",
"headers": {
"Authorization": "Bearer dummy"
},
"timeout_ms": 30000
}
]
}
```
Redis-like:
```bash
redis-cli -p 6379
SELECT 1
LANCE.EMBEDDING CONFIG SET products PROVIDER openai MODEL mock-local DIM 768 ENDPOINT http://127.0.0.1:8081/v1/embeddings HEADER Authorization "Bearer dummy" TIMEOUTMS 30000
```
Expected response:
```json
{"jsonrpc":"2.0","id":2,"result":true}
```
**Step 3: Verify configuration**
JSON-RPC:
```json
{
"jsonrpc": "2.0",
"id": 3,
"method": "herodb_lanceGetEmbeddingConfig",
"params": [1, "products"]
}
```
Redis-like:
```bash
LANCE.EMBEDDING CONFIG GET products
```
Expected: Returns your configuration with provider, model, dim, endpoint, etc.
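The Redis-style GET reply comes back as a flat array of alternating keys and values (`provider`, `model`, `dim`, and so on). A small illustrative helper to turn such a reply into a dict:

```python
def pairs_to_dict(flat):
    """Convert a flat [key1, value1, key2, value2, ...] reply into a dict."""
    if len(flat) % 2 != 0:
        raise ValueError("expected an even number of elements")
    return dict(zip(flat[0::2], flat[1::2]))

cfg = pairs_to_dict(["provider", "openai", "model", "mock-local", "dim", "768"])
assert cfg["dim"] == "768"  # values come back as strings
```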
**Step 4: Insert product data**
JSON-RPC (item 1):
```json
{
"jsonrpc": "2.0",
"id": 4,
"method": "herodb_lanceStoreText",
"params": [
1,
"products",
"item-1",
"Waterproof hiking boots with ankle support and aggressive tread",
{ "brand": "TrailMax", "category": "footwear", "price": "129.99" }
]
}
```
Redis-like:
```bash
LANCE.STORE products ID item-1 TEXT "Waterproof hiking boots with ankle support and aggressive tread" META brand TrailMax category footwear price 129.99
```
JSON-RPC (item 2):
```json
{
"jsonrpc": "2.0",
"id": 5,
"method": "herodb_lanceStoreText",
"params": [
1,
"products",
"item-2",
"Lightweight running shoes with breathable mesh upper",
{ "brand": "SpeedFit", "category": "footwear", "price": "89.99" }
]
}
```
JSON-RPC (item 3):
```json
{
"jsonrpc": "2.0",
"id": 6,
"method": "herodb_lanceStoreText",
"params": [
1,
"products",
"item-3",
"Insulated winter jacket with removable hood and multiple pockets",
{ "brand": "WarmTech", "category": "outerwear", "price": "199.99" }
]
}
```
JSON-RPC (item 4):
```json
{
"jsonrpc": "2.0",
"id": 7,
"method": "herodb_lanceStoreText",
"params": [
1,
"products",
"item-4",
"Camping tent for 4 people with waterproof rainfly",
{ "brand": "OutdoorPro", "category": "camping", "price": "249.99" }
]
}
```
Expected response for each: `{"jsonrpc":"2.0","id":N,"result":true}`
**Step 5: Search by text query**
JSON-RPC:
```json
{
"jsonrpc": "2.0",
"id": 8,
"method": "herodb_lanceSearchText",
"params": [
1,
"products",
"boots for hiking in wet conditions",
3,
null,
["brand", "category", "price"]
]
}
```
Redis-like:
```bash
LANCE.SEARCH products K 3 QUERY "boots for hiking in wet conditions" RETURN 3 brand category price
```
Expected response:
```json
{
"jsonrpc": "2.0",
"id": 8,
"result": {
"results": [
{
"id": "item-1",
"score": 0.234,
"meta": {
"brand": "TrailMax",
"category": "footwear",
"price": "129.99"
}
},
...
]
}
}
```
**Step 6: Search with metadata filter**
JSON-RPC:
```json
{
"jsonrpc": "2.0",
"id": 9,
"method": "herodb_lanceSearchText",
"params": [
1,
"products",
"comfortable shoes for running",
5,
"category = 'footwear'",
null
]
}
```
Redis-like:
```bash
LANCE.SEARCH products K 5 QUERY "comfortable shoes for running" FILTER "category = 'footwear'"
```
This returns only items where `category` equals `'footwear'`.
**Step 7: List datasets**
JSON-RPC:
```json
{
"jsonrpc": "2.0",
"id": 10,
"method": "herodb_lanceList",
"params": [1]
}
```
Redis-like:
```bash
LANCE.LIST
```
**Step 8: Get dataset info**
JSON-RPC:
```json
{
"jsonrpc": "2.0",
"id": 11,
"method": "herodb_lanceInfo",
"params": [1, "products"]
}
```
Redis-like:
```bash
LANCE.INFO products
```
Returns dimension, row count, and other metadata.
---
## Scenario 2: OpenAI API
Use OpenAI's production embedding service for semantic search.
### Setup
**1. Set your API key:**
```bash
export OPENAI_API_KEY="sk-your-actual-openai-key-here"
```
**2. Start HeroDB** (same as before):
```bash
./target/release/herodb --dir ./data --admin-secret my-admin-secret --enable-rpc --rpc-port 8080
```
### End-to-End Example with OpenAI
**Step 1: Create a Lance database**
JSON-RPC:
```json
{
"jsonrpc": "2.0",
"id": 1,
"method": "herodb_createDatabase",
"params": [
"Lance",
{ "name": "openai-vectors", "storage_path": null, "max_size": null, "redis_version": null },
null
]
}
```
Expected: `{"jsonrpc":"2.0","id":1,"result":1}` (database ID = 1)
**Step 2: Configure OpenAI embeddings**
JSON-RPC:
```json
{
"jsonrpc": "2.0",
"id": 2,
"method": "herodb_lanceSetEmbeddingConfig",
"params": [
1,
"documents",
{
"provider": "openai",
"model": "text-embedding-3-small",
"dim": 1536,
"endpoint": null,
"headers": {},
"timeout_ms": 30000
}
]
}
```
Redis-like:
```bash
redis-cli -p 6379
SELECT 1
LANCE.EMBEDDING CONFIG SET documents PROVIDER openai MODEL text-embedding-3-small DIM 1536 TIMEOUTMS 30000
```
Notes:
- `endpoint` is `null` (defaults to OpenAI API: https://api.openai.com/v1/embeddings)
- `headers` is empty (Authorization auto-added from OPENAI_API_KEY env var)
- `dim` is 1536 for text-embedding-3-small
Expected: `{"jsonrpc":"2.0","id":2,"result":true}`
**Step 3: Insert documents**
JSON-RPC:
```json
{
"jsonrpc": "2.0",
"id": 3,
"method": "herodb_lanceStoreText",
"params": [
1,
"documents",
"doc-1",
"The quick brown fox jumps over the lazy dog",
{ "source": "example", "lang": "en", "topic": "animals" }
]
}
```
```json
{
"jsonrpc": "2.0",
"id": 4,
"method": "herodb_lanceStoreText",
"params": [
1,
"documents",
"doc-2",
"Machine learning models require large datasets for training and validation",
{ "source": "tech", "lang": "en", "topic": "ai" }
]
}
```
```json
{
"jsonrpc": "2.0",
"id": 5,
"method": "herodb_lanceStoreText",
"params": [
1,
"documents",
"doc-3",
"Python is a popular programming language for data science and web development",
{ "source": "tech", "lang": "en", "topic": "programming" }
]
}
```
Redis-like:
```bash
LANCE.STORE documents ID doc-1 TEXT "The quick brown fox jumps over the lazy dog" META source example lang en topic animals
LANCE.STORE documents ID doc-2 TEXT "Machine learning models require large datasets for training and validation" META source tech lang en topic ai
LANCE.STORE documents ID doc-3 TEXT "Python is a popular programming language for data science and web development" META source tech lang en topic programming
```
Expected for each: `{"jsonrpc":"2.0","id":N,"result":true}`
**Step 4: Semantic search**
JSON-RPC:
```json
{
"jsonrpc": "2.0",
"id": 6,
"method": "herodb_lanceSearchText",
"params": [
1,
"documents",
"artificial intelligence and neural networks",
3,
null,
["source", "topic"]
]
}
```
Redis-like:
```bash
LANCE.SEARCH documents K 3 QUERY "artificial intelligence and neural networks" RETURN 2 source topic
```
Expected response (doc-2 should rank highest due to semantic similarity):
```json
{
"jsonrpc": "2.0",
"id": 6,
"result": {
"results": [
{
"id": "doc-2",
"score": 0.123,
"meta": {
"source": "tech",
"topic": "ai"
}
},
{
"id": "doc-3",
"score": 0.456,
"meta": {
"source": "tech",
"topic": "programming"
}
},
{
"id": "doc-1",
"score": 0.789,
"meta": {
"source": "example",
"topic": "animals"
}
}
]
}
}
```
Note: Lower score = better match (L2 distance).
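To make the scoring concrete, here is a small Python sketch of L2 (Euclidean) distance. HeroDB's exact scoring pipeline isn't shown here, but the ordering principle is the same: smaller distance means a closer match.

```python
import math

def l2_distance(a, b):
    """Euclidean (L2) distance between two vectors; smaller = more similar."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

query = [1.0, 0.0]
close = [0.9, 0.1]
far = [0.0, 1.0]
assert l2_distance(query, close) < l2_distance(query, far)
```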
**Step 5: Search with filter**
JSON-RPC:
```json
{
"jsonrpc": "2.0",
"id": 7,
"method": "herodb_lanceSearchText",
"params": [
1,
"documents",
"programming and software",
5,
"topic = 'programming'",
null
]
}
```
Redis-like:
```bash
LANCE.SEARCH documents K 5 QUERY "programming and software" FILTER "topic = 'programming'"
```
This returns only documents where `topic` equals `'programming'`.
---
## Scenario 3: Deterministic Test Embedder (No Network)
For CI/offline development, use the built-in test embedder that requires no external service.
### Configuration
JSON-RPC:
```json
{
"jsonrpc": "2.0",
"id": 1,
"method": "herodb_lanceSetEmbeddingConfig",
"params": [
1,
"testdata",
{
"provider": "test",
"model": "dev",
"dim": 64,
"endpoint": null,
"headers": {},
"timeout_ms": null
}
]
}
```
Redis-like:
```bash
SELECT 1
LANCE.EMBEDDING CONFIG SET testdata PROVIDER test MODEL dev DIM 64
```
### Usage
Use `lanceStoreText` and `lanceSearchText` as in previous scenarios. The embeddings are:
- Deterministic (same text → same vector)
- Fast (no network)
- Not semantic (hash-based, not ML)
Perfect for testing the vector storage/search mechanics without external dependencies.
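The built-in test embedder's exact algorithm isn't documented here; the sketch below is a hypothetical SHA-256-based stand-in (not HeroDB's implementation) that illustrates the same three properties:

```python
import hashlib
import math

def sketch_embed(text, dim=64):
    """Deterministic hash-based vector: SHA-256 bytes folded into `dim`
    slots, then L2-normalized. Hypothetical stand-in, not HeroDB's algorithm."""
    raw = bytearray()
    counter = 0
    while len(raw) < dim:
        raw += hashlib.sha256(f"{counter}:{text}".encode()).digest()
        counter += 1
    vec = [b / 255.0 for b in raw[:dim]]
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm > 0 else vec

# Deterministic: identical input -> identical vector
assert sketch_embed("hello") == sketch_embed("hello")
# Not semantic: near-identical spellings land far apart
assert sketch_embed("hello") != sketch_embed("hullo")
```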
---
## Advanced: Custom Headers and Timeouts
### Example: Local model with custom auth
JSON-RPC:
```json
{
"jsonrpc": "2.0",
"id": 1,
"method": "herodb_lanceSetEmbeddingConfig",
"params": [
1,
"secure-data",
{
"provider": "openai",
"model": "custom-model",
"dim": 512,
"endpoint": "http://192.168.1.100:9000/embeddings",
"headers": {
"Authorization": "Bearer my-local-token",
"X-Custom-Header": "value"
},
"timeout_ms": 60000
}
]
}
```
### Example: OpenAI with explicit API key (not from env)
JSON-RPC:
```json
{
"jsonrpc": "2.0",
"id": 1,
"method": "herodb_lanceSetEmbeddingConfig",
"params": [
1,
"dataset",
{
"provider": "openai",
"model": "text-embedding-3-small",
"dim": 1536,
"endpoint": null,
"headers": {
"Authorization": "Bearer sk-your-key-here"
},
"timeout_ms": 30000
}
]
}
```
---
## Troubleshooting
### Error: "Embedding config not set for dataset"
**Cause:** You tried to use `lanceStoreText` or `lanceSearchText` without configuring an embedder.
**Solution:** Run `lanceSetEmbeddingConfig` first.
### Error: "Embedding dimension mismatch: expected X, got Y"
**Cause:** The embedding service returned vectors of a different size than configured.
**Solution:**
- For OpenAI text-embedding-3-small, use `dim: 1536`
- For your local mock (from this tutorial), use `dim: 768`
- Check your embedding service's actual output dimension
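Before setting `dim`, you can check what your service actually returns. A small helper, assuming an OpenAI-shaped response body (the stub below is illustrative; a real response comes from your `/v1/embeddings` endpoint):

```python
def response_dim(resp):
    """Return the vector dimension from an OpenAI-shaped embeddings response."""
    return len(resp["data"][0]["embedding"])

# Stubbed response for illustration
stub = {"data": [{"embedding": [0.1, 0.2, 0.3]}]}
assert response_dim(stub) == 3  # configure DIM to match this number
```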
### Error: "Missing API key in env 'OPENAI_API_KEY'"
**Cause:** Using OpenAI provider without setting the API key.
**Solution:**
- Set `export OPENAI_API_KEY="sk-..."` before starting HeroDB, OR
- Pass the key explicitly in headers: `"Authorization": "Bearer sk-..."`
### Error: "HTTP request failed" or "Embeddings API error 404"
**Cause:** Cannot reach the embedding endpoint.
**Solution:**
- Verify your local server is running: `curl -X POST http://127.0.0.1:8081/v1/embeddings -H "Content-Type: application/json" -d '{"input":["ping"]}'`
- Check the endpoint URL in your config
- Ensure firewall allows the connection
### Error: "ERR DB backend is not Lance"
**Cause:** Trying to use LANCE.* commands on a non-Lance database.
**Solution:** Create the database with backend "Lance" (see Step 1).
### Error: "write permission denied"
**Cause:** Database is private and you haven't authenticated.
**Solution:** Use `SELECT <db_id> KEY <access-key>` or make the database public via RPC.
---
## Complete Example Script (Bash + curl)
Save as `test_embeddings.sh`:
```bash
#!/bin/bash
RPC_URL="http://localhost:8080"
# 1. Create Lance database
curl -X POST $RPC_URL -H "Content-Type: application/json" -d '{
"jsonrpc": "2.0",
"id": 1,
"method": "herodb_createDatabase",
"params": ["Lance", {"name": "test-vectors", "storage_path": null, "max_size": null, "redis_version": null}, null]
}'
echo -e "\n"
# 2. Configure local embedder
curl -X POST $RPC_URL -H "Content-Type: application/json" -d '{
"jsonrpc": "2.0",
"id": 2,
"method": "herodb_lanceSetEmbeddingConfig",
"params": [1, "products", {
"provider": "openai",
"model": "mock",
"dim": 768,
"endpoint": "http://127.0.0.1:8081/v1/embeddings",
"headers": {"Authorization": "Bearer dummy"},
"timeout_ms": 30000
}]
}'
echo -e "\n"
# 3. Insert data
curl -X POST $RPC_URL -H "Content-Type: application/json" -d '{
"jsonrpc": "2.0",
"id": 3,
"method": "herodb_lanceStoreText",
"params": [1, "products", "item-1", "Hiking boots", {"brand": "TrailMax"}]
}'
echo -e "\n"
# 4. Search
curl -X POST $RPC_URL -H "Content-Type: application/json" -d '{
"jsonrpc": "2.0",
"id": 4,
"method": "herodb_lanceSearchText",
"params": [1, "products", "outdoor footwear", 5, null, null]
}'
echo -e "\n"
```
Run:
```bash
chmod +x test_embeddings.sh
./test_embeddings.sh
```
---
## Summary
| Provider | Use Case | Endpoint | API Key |
|----------|----------|----------|---------|
| `openai` | Production semantic search | Default (OpenAI) or custom URL | OPENAI_API_KEY env or headers |
| `openai` | Local self-hosted gateway | http://127.0.0.1:8081/... | Optional (depends on your service) |
| `test` | CI/offline development | N/A (local hash) | None |
| `image_test` | Image testing | N/A (local hash) | None |
**Notes:**
- The `provider` field is always `"openai"` for OpenAI-compatible services, whether cloud or local, because HeroDB speaks the OpenAI embeddings API shape in both cases
- Use `endpoint` to point to your local service
- Use `headers` for custom authentication
- `dim` must match your embedding service's output dimension
- Once configured, `lanceStoreText` and `lanceSearchText` handle embedding automatically

mock_embedder.py Normal file

@@ -0,0 +1,34 @@
from flask import Flask, request, jsonify
import numpy as np

app = Flask(__name__)

@app.route('/v1/embeddings', methods=['POST'])
def embeddings():
    data = request.json
    inputs = data.get('input', [])
    if isinstance(inputs, str):
        inputs = [inputs]
    # Generate deterministic 768-dim embeddings (hash-based)
    embeddings = []
    for text in inputs:
        # Simple hash to vector
        vec = np.zeros(768)
        for i, char in enumerate(text[:768]):
            vec[i % 768] += ord(char) / 255.0
        # Normalize
        norm = np.linalg.norm(vec)
        if norm > 0:
            vec = vec / norm
        embeddings.append(vec.tolist())
    return jsonify({
        "data": [{"embedding": emb} for emb in embeddings],
        "model": data.get('model', 'mock'),
        "usage": {"total_tokens": sum(len(t) for t in inputs)}
    })

if __name__ == '__main__':
    app.run(host='127.0.0.1', port=8081)


@@ -1,4 +1,4 @@
-use crate::{error::DBError, protocol::Protocol, server::Server, embedding::{EmbeddingConfig, EmbeddingProvider}};
+use crate::{error::DBError, protocol::Protocol, server::Server, embedding::EmbeddingConfig};
 use base64::{engine::general_purpose, Engine as _};
 use tokio::time::{timeout, Duration};
 use futures::future::select_all;
@@ -170,9 +170,7 @@ pub enum Cmd {
     // Embedding configuration per dataset
     LanceEmbeddingConfigSet {
         name: String,
-        provider: String,
-        model: String,
-        params: Vec<(String, String)>,
+        config: EmbeddingConfig,
     },
     LanceEmbeddingConfigGet {
         name: String,
@@ -1089,20 +1087,25 @@ impl Cmd {
                 Cmd::LanceCreateIndex { name, index_type, params }
             }
             "lance.embedding" => {
-                // LANCE.EMBEDDING CONFIG SET name PROVIDER p MODEL m [PARAM k v ...]
+                // LANCE.EMBEDDING CONFIG SET name PROVIDER p MODEL m DIM d [ENDPOINT url] [HEADER k v]... [TIMEOUTMS t]
                 // LANCE.EMBEDDING CONFIG GET name
                 if cmd.len() < 3 || cmd[1].to_uppercase() != "CONFIG" {
                     return Err(DBError("ERR LANCE.EMBEDDING requires CONFIG subcommand".to_string()));
                 }
                 if cmd.len() >= 4 && cmd[2].to_uppercase() == "SET" {
                     if cmd.len() < 8 {
-                        return Err(DBError("ERR LANCE.EMBEDDING CONFIG SET requires: SET name PROVIDER p MODEL m [PARAM k v ...]".to_string()));
+                        return Err(DBError("ERR LANCE.EMBEDDING CONFIG SET requires: SET name PROVIDER p MODEL m DIM d [ENDPOINT url] [HEADER k v]... [TIMEOUTMS t]".to_string()));
                     }
                     let name = cmd[3].clone();
                     let mut i = 4;
                     let mut provider: Option<String> = None;
                     let mut model: Option<String> = None;
-                    let mut params: Vec<(String, String)> = Vec::new();
+                    let mut dim: Option<usize> = None;
+                    let mut endpoint: Option<String> = None;
+                    let mut headers: std::collections::HashMap<String, String> = std::collections::HashMap::new();
+                    let mut timeout_ms: Option<u64> = None;
                     while i < cmd.len() {
                         match cmd[i].to_uppercase().as_str() {
                             "PROVIDER" => {
@@ -1119,22 +1122,62 @@ impl Cmd {
                                 model = Some(cmd[i + 1].clone());
                                 i += 2;
                             }
-                            "PARAM" => {
-                                i += 1;
-                                while i + 1 < cmd.len() {
-                                    params.push((cmd[i].clone(), cmd[i + 1].clone()));
-                                    i += 2;
-                                }
+                            "DIM" => {
+                                if i + 1 >= cmd.len() {
+                                    return Err(DBError("ERR DIM requires a value".to_string()));
+                                }
+                                let d: usize = cmd[i + 1].parse().map_err(|_| DBError("ERR DIM must be an integer".to_string()))?;
+                                dim = Some(d);
+                                i += 2;
+                            }
+                            "ENDPOINT" => {
+                                if i + 1 >= cmd.len() {
+                                    return Err(DBError("ERR ENDPOINT requires a value".to_string()));
+                                }
+                                endpoint = Some(cmd[i + 1].clone());
+                                i += 2;
+                            }
+                            "HEADER" => {
+                                if i + 2 >= cmd.len() {
+                                    return Err(DBError("ERR HEADER requires key and value".to_string()));
+                                }
+                                headers.insert(cmd[i + 1].clone(), cmd[i + 2].clone());
+                                i += 3;
+                            }
+                            "TIMEOUTMS" => {
+                                if i + 1 >= cmd.len() {
+                                    return Err(DBError("ERR TIMEOUTMS requires a value".to_string()));
+                                }
+                                let t: u64 = cmd[i + 1].parse().map_err(|_| DBError("ERR TIMEOUTMS must be an integer".to_string()))?;
+                                timeout_ms = Some(t);
+                                i += 2;
                             }
                             _ => {
+                                // Unknown token; break to avoid infinite loop
                                 i += 1;
                             }
                         }
                     }
-                    let provider = provider.ok_or_else(|| DBError("ERR missing PROVIDER".to_string()))?;
+                    let provider_str = provider.ok_or_else(|| DBError("ERR missing PROVIDER".to_string()))?;
+                    let provider_enum = match provider_str.to_lowercase().as_str() {
+                        "openai" => crate::embedding::EmbeddingProvider::openai,
+                        "test" => crate::embedding::EmbeddingProvider::test,
+                        "image_test" | "imagetest" | "image-test" => crate::embedding::EmbeddingProvider::image_test,
+                        other => return Err(DBError(format!("ERR unsupported provider '{}'", other))),
+                    };
                     let model = model.ok_or_else(|| DBError("ERR missing MODEL".to_string()))?;
-                    Cmd::LanceEmbeddingConfigSet { name, provider, model, params }
+                    let dim = dim.ok_or_else(|| DBError("ERR missing DIM".to_string()))?;
+                    let config = EmbeddingConfig {
+                        provider: provider_enum,
+                        model,
+                        dim,
+                        endpoint,
+                        headers,
+                        timeout_ms,
+                    };
+                    Cmd::LanceEmbeddingConfigSet { name, config }
                 } else if cmd.len() == 4 && cmd[2].to_uppercase() == "GET" {
                     let name = cmd[3].clone();
                     Cmd::LanceEmbeddingConfigGet { name }
@@ -1437,25 +1480,14 @@ impl Cmd {
                     Err(e) => Ok(Protocol::err(&e.0)),
                 }
             }
-            Cmd::LanceEmbeddingConfigSet { name, provider, model, params } => {
+            Cmd::LanceEmbeddingConfigSet { name, config } => {
                 if !server.has_write_permission() {
                     return Ok(Protocol::err("ERR write permission denied"));
                 }
-                // Map provider string to enum
-                let p_lc = provider.to_lowercase();
-                let prov = match p_lc.as_str() {
-                    "test-hash" | "testhash" => EmbeddingProvider::TestHash,
-                    "testimagehash" | "image-test-hash" | "imagetesthash" => EmbeddingProvider::ImageTestHash,
-                    "fastembed" | "lancefastembed" => EmbeddingProvider::LanceFastEmbed,
-                    "openai" | "lanceopenai" => EmbeddingProvider::LanceOpenAI,
-                    other => EmbeddingProvider::LanceOther(other.to_string()),
-                };
-                let cfg = EmbeddingConfig {
-                    provider: prov,
-                    model,
-                    params: params.into_iter().collect(),
-                };
-                match server.set_dataset_embedding_config(&name, &cfg) {
+                if config.dim == 0 {
+                    return Ok(Protocol::err("ERR embedding DIM must be > 0"));
+                }
+                match server.set_dataset_embedding_config(&name, &config) {
                     Ok(()) => Ok(Protocol::SimpleString("OK".to_string())),
                     Err(e) => Ok(Protocol::err(&e.0)),
                 }
@@ -1466,16 +1498,20 @@ impl Cmd {
                         let mut arr = Vec::new();
                         arr.push(Protocol::BulkString("provider".to_string()));
                         arr.push(Protocol::BulkString(match cfg.provider {
-                            EmbeddingProvider::TestHash => "test-hash".to_string(),
-                            EmbeddingProvider::ImageTestHash => "testimagehash".to_string(),
-                            EmbeddingProvider::LanceFastEmbed => "lancefastembed".to_string(),
-                            EmbeddingProvider::LanceOpenAI => "lanceopenai".to_string(),
-                            EmbeddingProvider::LanceOther(ref s) => s.clone(),
+                            crate::embedding::EmbeddingProvider::openai => "openai".to_string(),
+                            crate::embedding::EmbeddingProvider::test => "test".to_string(),
+                            crate::embedding::EmbeddingProvider::image_test => "image_test".to_string(),
                         }));
                         arr.push(Protocol::BulkString("model".to_string()));
                         arr.push(Protocol::BulkString(cfg.model.clone()));
-                        arr.push(Protocol::BulkString("params".to_string()));
-                        arr.push(Protocol::BulkString(serde_json::to_string(&cfg.params).unwrap_or_else(|_| "{}".to_string())));
+                        arr.push(Protocol::BulkString("dim".to_string()));
+                        arr.push(Protocol::BulkString(cfg.dim.to_string()));
+                        arr.push(Protocol::BulkString("endpoint".to_string()));
+                        arr.push(Protocol::BulkString(cfg.endpoint.clone().unwrap_or_default()));
+                        arr.push(Protocol::BulkString("timeout_ms".to_string()));
+                        arr.push(Protocol::BulkString(cfg.timeout_ms.map(|v| v.to_string()).unwrap_or_default()));
+                        arr.push(Protocol::BulkString("headers".to_string()));
+                        arr.push(Protocol::BulkString(serde_json::to_string(&cfg.headers).unwrap_or_else(|_| "{}".to_string())));
                         Ok(Protocol::Array(arr))
                     }
                     Err(e) => Ok(Protocol::err(&e.0)),


@@ -1,4 +1,4 @@
-// Embedding abstraction and minimal providers.
+// Embedding abstraction with a single external provider (OpenAI-compatible) and local test providers.
 use std::collections::HashMap;
 use std::sync::Arc;
@@ -7,42 +7,41 @@ use serde::{Deserialize, Serialize};
 use crate::error::DBError;
-// Networking for OpenAI/Azure
+// Networking for OpenAI-compatible endpoints
 use std::time::Duration;
 use ureq::{Agent, AgentBuilder};
 use serde_json::json;
-/// Provider identifiers. Extend as needed to mirror LanceDB-supported providers.
+/// Provider identifiers (minimal set).
 #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
 #[serde(rename_all = "snake_case")]
 pub enum EmbeddingProvider {
-    // Deterministic, local-only embedder for CI and offline development (text).
-    TestHash,
-    // Deterministic, local-only embedder for CI and offline development (image).
-    ImageTestHash,
-    // Placeholders for LanceDB-supported providers; implementers can add concrete backends later.
-    LanceFastEmbed,
-    LanceOpenAI,
-    LanceOther(String),
+    /// External HTTP provider compatible with OpenAI's embeddings API.
+    openai,
+    /// Deterministic, local-only embedder for CI and offline development (text).
+    test,
+    /// Deterministic, local-only embedder for CI and offline development (image).
+    image_test,
 }
 /// Serializable embedding configuration.
-/// params: arbitrary key-value map for provider-specific knobs (e.g., "dim", "api_key_env", etc.)
+/// - provider: "openai" | "test" | "image_test"
+/// - model: provider/model id (e.g., "text-embedding-3-small"), may be ignored by local gateways
+/// - dim: required output dimension (used to create Lance datasets and validate outputs)
+/// - endpoint: optional HTTP endpoint (defaults to OpenAI API when provider == openai)
+/// - headers: optional HTTP headers (e.g., Authorization). If empty and OPENAI_API_KEY is present, Authorization will be inferred.
+/// - timeout_ms: optional HTTP timeout in milliseconds (for both read and write)
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct EmbeddingConfig {
     pub provider: EmbeddingProvider,
     pub model: String,
+    pub dim: usize,
     #[serde(default)]
-    pub params: HashMap<String, String>,
-}
-
-impl EmbeddingConfig {
-    pub fn get_param_usize(&self, key: &str) -> Option<usize> {
-        self.params.get(key).and_then(|v| v.parse::<usize>().ok())
-    }
-    pub fn get_param_string(&self, key: &str) -> Option<String> {
-        self.params.get(key).cloned()
-    }
+    pub endpoint: Option<String>,
+    #[serde(default)]
+    pub headers: HashMap<String, String>,
+    #[serde(default)]
+    pub timeout_ms: Option<u64>,
 }
 /// A provider-agnostic text embedding interface.
@@ -59,6 +58,20 @@ pub trait Embedder: Send + Sync {
     }
 }
+/// Image embedding interface (separate from text to keep modality-specific inputs).
+pub trait ImageEmbedder: Send + Sync {
+    /// Human-readable provider/model name
+    fn name(&self) -> String;
+    /// Embedding dimension
+    fn dim(&self) -> usize;
+    /// Embed a single image (raw bytes)
+    fn embed_image(&self, bytes: &[u8]) -> Result<Vec<f32>, DBError>;
+    /// Embed many images; default maps embed_image() over inputs
+    fn embed_many_images(&self, images: &[Vec<u8>]) -> Result<Vec<Vec<f32>>, DBError> {
+        images.iter().map(|b| self.embed_image(b)).collect()
+    }
+}
 //// ----------------------------- TEXT: deterministic test embedder -----------------------------
 /// Deterministic, no-deps, no-network embedder for CI and offline dev.
@@ -88,7 +101,7 @@ impl TestHashEmbedder {
 impl Embedder for TestHashEmbedder {
     fn name(&self) -> String {
-        format!("test-hash:{}", self.model_name)
+        format!("test:{}", self.model_name)
     }
     fn dim(&self) -> usize {
@@ -117,21 +130,7 @@ impl Embedder for TestHashEmbedder {
     }
 }
-//// ----------------------------- IMAGE: trait + deterministic test embedder -----------------------------
-/// Image embedding interface (separate from text to keep modality-specific inputs).
-pub trait ImageEmbedder: Send + Sync {
-    /// Human-readable provider/model name
-    fn name(&self) -> String;
-    /// Embedding dimension
-    fn dim(&self) -> usize;
-    /// Embed a single image (raw bytes)
-    fn embed_image(&self, bytes: &[u8]) -> Result<Vec<f32>, DBError>;
-    /// Embed many images; default maps embed_image() over inputs
-    fn embed_many_images(&self, images: &[Vec<u8>]) -> Result<Vec<Vec<f32>>, DBError> {
-        images.iter().map(|b| self.embed_image(b)).collect()
-    }
-}
+//// ----------------------------- IMAGE: deterministic test embedder -----------------------------
 /// Deterministic image embedder that folds bytes into buckets, applies tanh-like nonlinearity,
 /// and L2-normalizes. Suitable for CI and offline development.
@@ -159,7 +158,7 @@ impl TestImageHashEmbedder {
 impl ImageEmbedder for TestImageHashEmbedder {
     fn name(&self) -> String {
-        format!("test-image-hash:{}", self.model_name)
+        format!("image_test:{}", self.model_name)
     }
     fn dim(&self) -> usize {
@@ -188,80 +187,45 @@ impl ImageEmbedder for TestImageHashEmbedder {
     }
 }
-//// OpenAI embedder (supports OpenAI and Azure OpenAI via REST)
+//// ----------------------------- OpenAI-compatible HTTP embedder -----------------------------
 struct OpenAIEmbedder {
     model: String,
     dim: usize,
     agent: Agent,
    endpoint: String,
    headers: Vec<(String, String)>,
use_azure: bool,
} }
impl OpenAIEmbedder { impl OpenAIEmbedder {
fn new_from_config(cfg: &EmbeddingConfig) -> Result<Self, DBError> { fn new_from_config(cfg: &EmbeddingConfig) -> Result<Self, DBError> {
// Whether to use Azure OpenAI
let use_azure = cfg
.get_param_string("use_azure")
.map(|s| s.eq_ignore_ascii_case("true"))
.unwrap_or(false);
// Resolve API key (OPENAI_API_KEY or AZURE_OPENAI_API_KEY by default)
let api_key_env = cfg
.get_param_string("api_key_env")
.unwrap_or_else(|| {
if use_azure {
"AZURE_OPENAI_API_KEY".to_string()
} else {
"OPENAI_API_KEY".to_string()
}
});
let api_key = std::env::var(&api_key_env)
.map_err(|_| DBError(format!("Missing API key in env '{}'", api_key_env)))?;
// Resolve endpoint // Resolve endpoint
// - Standard OpenAI: https://api.openai.com/v1/embeddings (default) or params["base_url"] let endpoint = cfg.endpoint.clone().unwrap_or_else(|| {
// - Azure OpenAI: {azure_endpoint}/openai/deployments/{deployment}/embeddings?api-version=... "https://api.openai.com/v1/embeddings".to_string()
let endpoint = if use_azure { });
let base = cfg
.get_param_string("azure_endpoint")
.ok_or_else(|| DBError("Missing 'azure_endpoint' for Azure OpenAI".into()))?;
let deployment = cfg
.get_param_string("azure_deployment")
.unwrap_or_else(|| cfg.model.clone());
let api_version = cfg
.get_param_string("azure_api_version")
.unwrap_or_else(|| "2023-05-15".to_string());
format!(
"{}/openai/deployments/{}/embeddings?api-version={}",
base.trim_end_matches('/'),
deployment,
api_version
)
} else {
cfg.get_param_string("base_url")
.unwrap_or_else(|| "https://api.openai.com/v1/embeddings".to_string())
};
// Determine expected dimension (default 1536 for text-embedding-3-small; callers should override if needed) // Determine expected dimension (required by config)
let dim = cfg let dim = cfg.dim;
.get_param_usize("dim")
.or_else(|| cfg.get_param_usize("dimensions"))
.unwrap_or(1536);
// Build an HTTP agent with timeouts (blocking; no tokio runtime involved) // Build an HTTP agent with timeouts (blocking; no tokio runtime involved)
let to_ms = cfg.timeout_ms.unwrap_or(30_000);
let agent = AgentBuilder::new() let agent = AgentBuilder::new()
.timeout_read(Duration::from_secs(30)) .timeout_read(Duration::from_millis(to_ms))
.timeout_write(Duration::from_secs(30)) .timeout_write(Duration::from_millis(to_ms))
.build(); .build();
// Headers // Headers: start from cfg.headers, and add Authorization from env if absent and available
let mut headers: Vec<(String, String)> = Vec::new(); let mut headers: Vec<(String, String)> =
headers.push(("Content-Type".to_string(), "application/json".to_string())); cfg.headers.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
if use_azure {
headers.push(("api-key".to_string(), api_key)); if !headers.iter().any(|(k, _)| k.eq_ignore_ascii_case("content-type")) {
} else { headers.push(("Content-Type".to_string(), "application/json".to_string()));
headers.push(("Authorization".to_string(), format!("Bearer {}", api_key))); }
if !headers.iter().any(|(k, _)| k.eq_ignore_ascii_case("authorization")) {
if let Ok(key) = std::env::var("OPENAI_API_KEY") {
headers.push(("Authorization".to_string(), format!("Bearer {}", key)));
}
} }
Ok(Self { Ok(Self {
@@ -270,19 +234,12 @@ impl OpenAIEmbedder {
agent, agent,
endpoint, endpoint,
headers, headers,
use_azure,
}) })
} }
fn request_many(&self, inputs: &[String]) -> Result<Vec<Vec<f32>>, DBError> { fn request_many(&self, inputs: &[String]) -> Result<Vec<Vec<f32>>, DBError> {
// Compose request body: // Compose request body (OpenAI-compatible)
// - Standard OpenAI: { "model": ..., "input": [...], "dimensions": dim? } let mut body = json!({ "model": self.model, "input": inputs });
// - Azure: { "input": [...], "dimensions": dim? } (model from deployment)
let mut body = if self.use_azure {
json!({ "input": inputs })
} else {
json!({ "model": self.model, "input": inputs })
};
if self.dim > 0 { if self.dim > 0 {
body.as_object_mut() body.as_object_mut()
.unwrap() .unwrap()
@@ -331,7 +288,7 @@ impl OpenAIEmbedder {
} }
if self.dim > 0 && v.len() != self.dim { if self.dim > 0 && v.len() != self.dim {
return Err(DBError(format!( return Err(DBError(format!(
"Embedding dimension mismatch: expected {}, got {}. Configure 'dim' or 'dimensions' to match output.", "Embedding dimension mismatch: expected {}, got {}. Configure 'dim' to match output.",
self.dim, v.len() self.dim, v.len()
))); )));
} }
@@ -343,11 +300,7 @@ impl OpenAIEmbedder {
impl Embedder for OpenAIEmbedder { impl Embedder for OpenAIEmbedder {
fn name(&self) -> String { fn name(&self) -> String {
if self.use_azure { format!("openai:{}", self.model)
format!("azure-openai:{}", self.model)
} else {
format!("openai:{}", self.model)
}
} }
fn dim(&self) -> usize { fn dim(&self) -> usize {
@@ -368,38 +321,33 @@ impl Embedder for OpenAIEmbedder {
} }
/// Create an embedder instance from a config. /// Create an embedder instance from a config.
/// - TestHash: uses params["dim"] or defaults to 64 /// - openai: uses OpenAI-compatible embeddings REST API (endpoint override supported)
/// - LanceOpenAI: uses OpenAI (or Azure OpenAI) embeddings REST API /// - test: deterministic local text embedder (no network)
/// - Other Lance providers can be added similarly /// - image_test: not valid for text (use create_image_embedder)
pub fn create_embedder(config: &EmbeddingConfig) -> Result<Arc<dyn Embedder>, DBError> { pub fn create_embedder(config: &EmbeddingConfig) -> Result<Arc<dyn Embedder>, DBError> {
match &config.provider { match &config.provider {
EmbeddingProvider::TestHash => { EmbeddingProvider::openai => {
let dim = config.get_param_usize("dim").unwrap_or(64);
Ok(Arc::new(TestHashEmbedder::new(dim, config.model.clone())))
}
EmbeddingProvider::LanceOpenAI => {
let inner = OpenAIEmbedder::new_from_config(config)?; let inner = OpenAIEmbedder::new_from_config(config)?;
Ok(Arc::new(inner)) Ok(Arc::new(inner))
} }
EmbeddingProvider::ImageTestHash => { EmbeddingProvider::test => {
Ok(Arc::new(TestHashEmbedder::new(config.dim, config.model.clone())))
}
EmbeddingProvider::image_test => {
Err(DBError("Use create_image_embedder() for image providers".into())) Err(DBError("Use create_image_embedder() for image providers".into()))
} }
EmbeddingProvider::LanceFastEmbed => Err(DBError("LanceFastEmbed provider not yet implemented in Rust embedding layer; configure 'test-hash' or use 'openai'".into())),
EmbeddingProvider::LanceOther(p) => Err(DBError(format!("Lance provider '{}' not implemented; configure 'openai' or 'test-hash'", p))),
} }
} }
/// Create an image embedder instance from a config. /// Create an image embedder instance from a config.
/// - image_test: deterministic local image embedder
pub fn create_image_embedder(config: &EmbeddingConfig) -> Result<Arc<dyn ImageEmbedder>, DBError> { pub fn create_image_embedder(config: &EmbeddingConfig) -> Result<Arc<dyn ImageEmbedder>, DBError> {
match &config.provider { match &config.provider {
EmbeddingProvider::ImageTestHash => { EmbeddingProvider::image_test => {
let dim = config.get_param_usize("dim").unwrap_or(512); Ok(Arc::new(TestImageHashEmbedder::new(config.dim, config.model.clone())))
Ok(Arc::new(TestImageHashEmbedder::new(dim, config.model.clone())))
} }
EmbeddingProvider::TestHash | EmbeddingProvider::LanceOpenAI => { EmbeddingProvider::test | EmbeddingProvider::openai => {
Err(DBError("Configured text provider; dataset expects image provider (e.g., 'testimagehash')".into())) Err(DBError("Configured text provider; dataset expects image provider (e.g., 'image_test')".into()))
} }
EmbeddingProvider::LanceFastEmbed => Err(DBError("Image provider 'lancefastembed' not yet implemented".into())),
EmbeddingProvider::LanceOther(p) => Err(DBError(format!("Image provider '{}' not implemented; use 'testimagehash' for now", p))),
} }
} }
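With this change, a dataset embedding config serializes directly from the `EmbeddingConfig` struct, so the wire format is one flat JSON object. A minimal sketch of that shape with illustrative values (field names come from the struct above; the endpoint and header values are assumptions, and the optional fields may be omitted because of `#[serde(default)]`):

```python
import json

# Illustrative EmbeddingConfig payloads matching the struct fields above.
# provider, model, and dim are required; endpoint, headers, and timeout_ms
# are optional and default when absent.
openai_config = {
    "provider": "openai",
    "model": "text-embedding-3-small",
    "dim": 1536,
    "endpoint": "http://localhost:8081/v1/embeddings",  # assumed local mock server
    "headers": {"X-Demo": "1"},                          # merged into request headers
    "timeout_ms": 30000,
}

# The deterministic "test" provider needs only the required fields:
test_config = {"provider": "test", "model": "test-model", "dim": 64}

print(json.dumps(openai_config, indent=2))
```

Note that when `Authorization` is absent from `headers`, the embedder falls back to the `OPENAI_API_KEY` environment variable, so real OpenAI configs usually omit the key from the stored config.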

RPC layer changes in the same commit (separate file):

```diff
@@ -9,7 +9,7 @@
 use crate::server::Server;
 use crate::options::DBOption;
 use crate::admin_meta;
-use crate::embedding::{EmbeddingConfig, EmbeddingProvider};
+use crate::embedding::EmbeddingConfig;
 use base64::{engine::general_purpose, Engine as _};
@@ -248,9 +248,7 @@ pub trait Rpc {
         &self,
         db_id: u64,
         name: String,
-        provider: String,
-        model: String,
-        params: Option<HashMap<String, String>>,
+        config: EmbeddingConfig,
     ) -> RpcResult<bool>;
 
     /// Get per-dataset embedding configuration
@@ -1008,9 +1006,7 @@ impl RpcServer for RpcServerImpl {
         &self,
         db_id: u64,
         name: String,
-        provider: String,
-        model: String,
-        params: Option<HashMap<String, String>>,
+        config: EmbeddingConfig,
     ) -> RpcResult<bool> {
         let server = self.get_or_create_server(db_id).await?;
         if db_id == 0 {
@@ -1022,19 +1018,17 @@
         if !server.has_write_permission() {
             return Err(jsonrpsee::types::ErrorObjectOwned::owned(-32000, "write permission denied", None::<()>));
         }
-        let prov = match provider.to_lowercase().as_str() {
-            "test-hash" | "testhash" => EmbeddingProvider::TestHash,
-            "testimagehash" | "image-test-hash" | "imagetesthash" => EmbeddingProvider::ImageTestHash,
-            "fastembed" | "lancefastembed" => EmbeddingProvider::LanceFastEmbed,
-            "openai" | "lanceopenai" => EmbeddingProvider::LanceOpenAI,
-            other => EmbeddingProvider::LanceOther(other.to_string()),
-        };
-        let cfg = EmbeddingConfig {
-            provider: prov,
-            model,
-            params: params.unwrap_or_default(),
-        };
-        server.set_dataset_embedding_config(&name, &cfg)
+        // Validate provider and dimension (only a minimal set is allowed for now)
+        match config.provider {
+            crate::embedding::EmbeddingProvider::openai
+            | crate::embedding::EmbeddingProvider::test
+            | crate::embedding::EmbeddingProvider::image_test => {}
+        }
+        if config.dim == 0 {
+            return Err(jsonrpsee::types::ErrorObjectOwned::owned(-32000, "Invalid embedding config: dim must be > 0", None::<()>));
+        }
+
+        server.set_dataset_embedding_config(&name, &config)
             .map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?;
         Ok(true)
     }
@@ -1056,17 +1050,7 @@
         }
         let cfg = server.get_dataset_embedding_config(&name)
             .map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?;
-        Ok(serde_json::json!({
-            "provider": match cfg.provider {
-                EmbeddingProvider::TestHash => "test-hash",
-                EmbeddingProvider::ImageTestHash => "testimagehash",
-                EmbeddingProvider::LanceFastEmbed => "lancefastembed",
-                EmbeddingProvider::LanceOpenAI => "lanceopenai",
-                EmbeddingProvider::LanceOther(ref s) => s,
-            },
-            "model": cfg.model,
-            "params": cfg.params
-        }))
+        Ok(serde_json::to_value(&cfg).unwrap_or(serde_json::json!({})))
     }
 
     async fn lance_store_text(
```
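After the RPC signature change, the embedding config travels as a single positional object instead of separate `provider`/`model`/`params` arguments. A hedged sketch of the new JSON-RPC request body follows; the method name is a hypothetical placeholder for illustration (it is not shown in the diff), so check the `Rpc` trait for the actual name:

```python
import json

# Hypothetical JSON-RPC request shape after this change: the third positional
# parameter is the whole EmbeddingConfig object.
request = {
    "jsonrpc": "2.0",
    "id": 1,
    "method": "herodb_setDatasetEmbeddingConfig",  # assumed name, not from the diff
    "params": [
        1,             # db_id
        "my-dataset",  # dataset name
        {"provider": "test", "model": "test-model", "dim": 64},
    ],
}

# The handler rejects dim == 0, so a valid request always carries dim > 0.
body = json.dumps(request)
print(body)
```

The matching getter now returns the config via `serde_json::to_value`, so the response echoes this same object shape back, including any defaulted optional fields.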