feat: Add Jina client training and classification features

- Added `train` function to the Jina client for training
  classifiers.
- Added `ClassificationTrain` struct to define training
  parameters.
- Added `TrainingExample` struct to represent training data.
- Added `ClassificationTrainOutput` struct for the training
  response.
- Added a new `classification_api.v` module for classifier
  training functionalities.
- Added a new `classify` function to the Jina client for
  classification tasks (currently commented out).
This commit is contained in:
Mahmoud Emad
2025-03-11 20:17:35 +02:00
parent 0e1836c5d0
commit 9ecc2444aa
4 changed files with 169 additions and 34 deletions

View File

@@ -4,6 +4,7 @@ import freeflowuniverse.herolib.clients.jina
mut jina_client := jina.get()!
// Create embeddings
embeddings := jina_client.create_embeddings(
input: ['Hello', 'World']
model: .jina_embeddings_v3
@@ -12,6 +13,7 @@ embeddings := jina_client.create_embeddings(
println('Created embeddings: ${embeddings}')
// Rerank
rerank_result := jina_client.rerank(
model: .reranker_v2_base_multilingual
query: 'skincare products'
@@ -20,3 +22,30 @@ rerank_result := jina_client.rerank(
) or { panic('Error while reranking: ${err}') }
println('Rerank result: ${rerank_result}')
// Train
train_result := jina_client.train(
model: .jina_clip_v1
input: [
jina.TrainingExample{
text: 'Sample text'
label: 'positive'
},
jina.TrainingExample{
image: 'https://letsenhance.io/static/73136da51c245e80edc6ccfe44888a99/1015f/MainBefore.jpg'
label: 'negative'
},
]
) or { panic('Error while training: ${err}') }
println('Train result: ${train_result}')
// // Classify
// classification_result := jina_client.classify(
// model: .reranker_v2_base_multilingual
// query: 'skincare products'
// documents: ['Product A', 'Product B', 'Product C']
// top_n: 2
// ) or { panic('Error while classifying: ${err}') }
// println('Classification result: ${classification_result}')

View File

@@ -0,0 +1,132 @@
module jina
import json
import freeflowuniverse.herolib.core.httpconnection
// ClassificationTrainAccess represents the accessibility of the classifier
pub enum ClassificationTrainAccess {
public // Classifier is publicly accessible
private // Classifier is private (default)
}
// TrainingExample represents a single training example (either text or image with a label)
pub struct TrainingExample {
pub mut:
text ?string // Optional text content
image ?string // Optional image URL
label string // Required label
}
// ClassificationTrainOutput represents the response from the training endpoint
pub struct ClassificationTrainOutput {
pub mut:
classifier_id string // Identifier of the trained classifier
num_samples int // Number of samples used in training
usage ClassificationTrainUsage // Token usage details
}
// ClassificationTrainUsage represents token usage for the training request
pub struct ClassificationTrainUsage {
pub mut:
total_tokens int // Total tokens consumed
}
// ClassificationTrain represents parameters for the training request
@[params]
pub struct ClassificationTrain {
pub mut:
model ?JinaModel // Optional model identifier (e.g., jina-clip-v1)
classifier_id ?string // Optional existing classifier ID
access ?ClassificationTrainAccess = .private // Accessibility, defaults to private
input []TrainingExample // Array of training examples
num_iters ?int = 10 // Number of training iterations, defaults to 10
}
// TrainRequest represents the JSON request body for the /v1/train endpoint
struct TrainRequest {
mut:
model ?string
classifier_id ?string
access ?string
input []TrainingExample
num_iters ?int
}
// Train a classifier by sending a POST request to /v1/train
pub fn (mut j Jina) train(params ClassificationTrain) !ClassificationTrainOutput {
// Validate that only one of model or classifier_id is provided
mut model_provided := false
mut classifier_id_provided := false
if _ := params.model {
model_provided = true
}
if _ := params.classifier_id {
classifier_id_provided = true
}
if model_provided && classifier_id_provided {
return error('Provide either model or classifier_id, not both')
}
if model := params.model {
if model == .jina_embeddings_v3 {
return error('jina-embeddings-v3 is not a valid model for classification')
}
}
// Validate each training example has exactly one of text or image
for example in params.input {
mut text_provided := false
mut image_provided := false
if _ := example.text {
text_provided = true
}
if _ := example.image {
image_provided = true
}
if text_provided && image_provided {
return error('Each training example must have either text or image, not both')
}
if !text_provided && !image_provided {
return error('Each training example must have either text or image')
}
}
// Construct the request body
mut request := TrainRequest{
input: params.input
}
if v := params.model {
request.model = v.to_string() // Convert JinaModel enum to string
}
if v := params.classifier_id {
request.classifier_id = v
}
if v := params.access {
request.access = match v {
.public { 'public' }
.private { 'private' }
}
}
if v := params.num_iters {
request.num_iters = v
}
// Create and send the HTTP request
req := httpconnection.Request{
method: .post
prefix: 'v1/train'
dataformat: .json
data: json.encode(request)
}
mut httpclient := j.httpclient()!
response := httpclient.post_json_str(req)!
result := json.decode(ClassificationTrainOutput, response)!
return result
}

View File

@@ -85,6 +85,14 @@ pub fn (mut j Jina) rerank(params RerankParams) !RankingOutput {
return json.decode(RankingOutput, response)!
}
@[params]
pub struct ClassifyParams {
pub mut:
model string @[required] // The classification model
input []string @[required] // Input texts or image URLs
labels []string @[required] // Classification labels
}
// // Create embeddings with a TextDoc input
// pub fn (mut j Jina) create_embeddings_with_docs(args TextEmbeddingInput) !ModelEmbeddingOutput {

View File

@@ -33,30 +33,6 @@ pub mut:
score f64
}
// TrainingExample represents a single training example for classifier training
pub struct TrainingExample {
pub mut:
text string
label string
}
// TrainingAPIInput represents the input for training a classifier
pub struct TrainingAPIInput {
pub mut:
model string @[required]
input []TrainingExample @[required]
access string // Optional: "public" or "private"
}
// TrainingOutput represents the response from training a classifier
pub struct TrainingOutput {
pub mut:
classifier_id string
model string
status string
object string
}
// BulkEmbeddingJobResponse represents the response from bulk embedding operations
pub struct BulkEmbeddingJobResponse {
pub mut:
@@ -148,16 +124,6 @@ pub fn parse_classification_output(json_str string) !ClassificationOutput {
return json.decode(ClassificationOutput, json_str)
}
// Serialize TrainingAPIInput to JSON
pub fn (input TrainingAPIInput) to_json() string {
return json.encode(input)
}
// Parse JSON to TrainingOutput
pub fn parse_training_output(json_str string) !TrainingOutput {
return json.decode(TrainingOutput, json_str)
}
// Parse JSON to BulkEmbeddingJobResponse
pub fn parse_bulk_embedding_job_response(json_str string) !BulkEmbeddingJobResponse {
return json.decode(BulkEmbeddingJobResponse, json_str)