feat: Add Jina client training and classification features
- Added `train` function to the Jina client for training classifiers. - Added `ClassificationTrain` struct to define training parameters. - Added `TrainingExample` struct to represent training data. - Added `ClassificationTrainOutput` struct for the training response. - Added a new `classification_api.v` module for classifier training functionalities. - Added a new `classify` function to the Jina client for classification tasks (currently commented out).
This commit is contained in:
@@ -4,6 +4,7 @@ import freeflowuniverse.herolib.clients.jina
|
|||||||
|
|
||||||
mut jina_client := jina.get()!
|
mut jina_client := jina.get()!
|
||||||
|
|
||||||
|
// Create embeddings
|
||||||
embeddings := jina_client.create_embeddings(
|
embeddings := jina_client.create_embeddings(
|
||||||
input: ['Hello', 'World']
|
input: ['Hello', 'World']
|
||||||
model: .jina_embeddings_v3
|
model: .jina_embeddings_v3
|
||||||
@@ -12,6 +13,7 @@ embeddings := jina_client.create_embeddings(
|
|||||||
|
|
||||||
println('Created embeddings: ${embeddings}')
|
println('Created embeddings: ${embeddings}')
|
||||||
|
|
||||||
|
// Rerank
|
||||||
rerank_result := jina_client.rerank(
|
rerank_result := jina_client.rerank(
|
||||||
model: .reranker_v2_base_multilingual
|
model: .reranker_v2_base_multilingual
|
||||||
query: 'skincare products'
|
query: 'skincare products'
|
||||||
@@ -20,3 +22,30 @@ rerank_result := jina_client.rerank(
|
|||||||
) or { panic('Error while reranking: ${err}') }
|
) or { panic('Error while reranking: ${err}') }
|
||||||
|
|
||||||
println('Rerank result: ${rerank_result}')
|
println('Rerank result: ${rerank_result}')
|
||||||
|
|
||||||
|
// Train
|
||||||
|
train_result := jina_client.train(
|
||||||
|
model: .jina_clip_v1
|
||||||
|
input: [
|
||||||
|
jina.TrainingExample{
|
||||||
|
text: 'Sample text'
|
||||||
|
label: 'positive'
|
||||||
|
},
|
||||||
|
jina.TrainingExample{
|
||||||
|
image: 'https://letsenhance.io/static/73136da51c245e80edc6ccfe44888a99/1015f/MainBefore.jpg'
|
||||||
|
label: 'negative'
|
||||||
|
},
|
||||||
|
]
|
||||||
|
) or { panic('Error while training: ${err}') }
|
||||||
|
|
||||||
|
println('Train result: ${train_result}')
|
||||||
|
|
||||||
|
// // Classify
|
||||||
|
// classification_result := jina_client.classify(
|
||||||
|
// model: .reranker_v2_base_multilingual
|
||||||
|
// query: 'skincare products'
|
||||||
|
// documents: ['Product A', 'Product B', 'Product C']
|
||||||
|
// top_n: 2
|
||||||
|
// ) or { panic('Error while classifying: ${err}') }
|
||||||
|
|
||||||
|
// println('Classification result: ${classification_result}')
|
||||||
|
|||||||
132
lib/clients/jina/classification_api.v
Normal file
132
lib/clients/jina/classification_api.v
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
module jina
|
||||||
|
|
||||||
|
import json
|
||||||
|
import freeflowuniverse.herolib.core.httpconnection
|
||||||
|
|
||||||
|
// ClassificationTrainAccess represents the accessibility of the classifier
|
||||||
|
pub enum ClassificationTrainAccess {
|
||||||
|
public // Classifier is publicly accessible
|
||||||
|
private // Classifier is private (default)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TrainingExample represents a single training example (either text or image with a label)
|
||||||
|
pub struct TrainingExample {
|
||||||
|
pub mut:
|
||||||
|
text ?string // Optional text content
|
||||||
|
image ?string // Optional image URL
|
||||||
|
label string // Required label
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClassificationTrainOutput represents the response from the training endpoint
|
||||||
|
pub struct ClassificationTrainOutput {
|
||||||
|
pub mut:
|
||||||
|
classifier_id string // Identifier of the trained classifier
|
||||||
|
num_samples int // Number of samples used in training
|
||||||
|
usage ClassificationTrainUsage // Token usage details
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClassificationTrainUsage represents token usage for the training request
|
||||||
|
pub struct ClassificationTrainUsage {
|
||||||
|
pub mut:
|
||||||
|
total_tokens int // Total tokens consumed
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClassificationTrain represents parameters for the training request
|
||||||
|
@[params]
|
||||||
|
pub struct ClassificationTrain {
|
||||||
|
pub mut:
|
||||||
|
model ?JinaModel // Optional model identifier (e.g., jina-clip-v1)
|
||||||
|
classifier_id ?string // Optional existing classifier ID
|
||||||
|
access ?ClassificationTrainAccess = .private // Accessibility, defaults to private
|
||||||
|
input []TrainingExample // Array of training examples
|
||||||
|
num_iters ?int = 10 // Number of training iterations, defaults to 10
|
||||||
|
}
|
||||||
|
|
||||||
|
// TrainRequest represents the JSON request body for the /v1/train endpoint
|
||||||
|
struct TrainRequest {
|
||||||
|
mut:
|
||||||
|
model ?string
|
||||||
|
classifier_id ?string
|
||||||
|
access ?string
|
||||||
|
input []TrainingExample
|
||||||
|
num_iters ?int
|
||||||
|
}
|
||||||
|
|
||||||
|
// Train a classifier by sending a POST request to /v1/train
|
||||||
|
pub fn (mut j Jina) train(params ClassificationTrain) !ClassificationTrainOutput {
|
||||||
|
// Validate that only one of model or classifier_id is provided
|
||||||
|
mut model_provided := false
|
||||||
|
mut classifier_id_provided := false
|
||||||
|
if _ := params.model {
|
||||||
|
model_provided = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if _ := params.classifier_id {
|
||||||
|
classifier_id_provided = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if model_provided && classifier_id_provided {
|
||||||
|
return error('Provide either model or classifier_id, not both')
|
||||||
|
}
|
||||||
|
|
||||||
|
if model := params.model {
|
||||||
|
if model == .jina_embeddings_v3 {
|
||||||
|
return error('jina-embeddings-v3 is not a valid model for classification')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate each training example has exactly one of text or image
|
||||||
|
for example in params.input {
|
||||||
|
mut text_provided := false
|
||||||
|
mut image_provided := false
|
||||||
|
|
||||||
|
if _ := example.text {
|
||||||
|
text_provided = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if _ := example.image {
|
||||||
|
image_provided = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if text_provided && image_provided {
|
||||||
|
return error('Each training example must have either text or image, not both')
|
||||||
|
}
|
||||||
|
|
||||||
|
if !text_provided && !image_provided {
|
||||||
|
return error('Each training example must have either text or image')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Construct the request body
|
||||||
|
mut request := TrainRequest{
|
||||||
|
input: params.input
|
||||||
|
}
|
||||||
|
if v := params.model {
|
||||||
|
request.model = v.to_string() // Convert JinaModel enum to string
|
||||||
|
}
|
||||||
|
if v := params.classifier_id {
|
||||||
|
request.classifier_id = v
|
||||||
|
}
|
||||||
|
if v := params.access {
|
||||||
|
request.access = match v {
|
||||||
|
.public { 'public' }
|
||||||
|
.private { 'private' }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if v := params.num_iters {
|
||||||
|
request.num_iters = v
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create and send the HTTP request
|
||||||
|
req := httpconnection.Request{
|
||||||
|
method: .post
|
||||||
|
prefix: 'v1/train'
|
||||||
|
dataformat: .json
|
||||||
|
data: json.encode(request)
|
||||||
|
}
|
||||||
|
|
||||||
|
mut httpclient := j.httpclient()!
|
||||||
|
response := httpclient.post_json_str(req)!
|
||||||
|
result := json.decode(ClassificationTrainOutput, response)!
|
||||||
|
return result
|
||||||
|
}
|
||||||
@@ -85,6 +85,14 @@ pub fn (mut j Jina) rerank(params RerankParams) !RankingOutput {
|
|||||||
return json.decode(RankingOutput, response)!
|
return json.decode(RankingOutput, response)!
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@[params]
|
||||||
|
pub struct ClassifyParams {
|
||||||
|
pub mut:
|
||||||
|
model string @[required] // The classification model
|
||||||
|
input []string @[required] // Input texts or image URLs
|
||||||
|
labels []string @[required] // Classification labels
|
||||||
|
}
|
||||||
|
|
||||||
// // Create embeddings with a TextDoc input
|
// // Create embeddings with a TextDoc input
|
||||||
// pub fn (mut j Jina) create_embeddings_with_docs(args TextEmbeddingInput) !ModelEmbeddingOutput {
|
// pub fn (mut j Jina) create_embeddings_with_docs(args TextEmbeddingInput) !ModelEmbeddingOutput {
|
||||||
|
|
||||||
|
|||||||
@@ -33,30 +33,6 @@ pub mut:
|
|||||||
score f64
|
score f64
|
||||||
}
|
}
|
||||||
|
|
||||||
// TrainingExample represents a single training example for classifier training
|
|
||||||
pub struct TrainingExample {
|
|
||||||
pub mut:
|
|
||||||
text string
|
|
||||||
label string
|
|
||||||
}
|
|
||||||
|
|
||||||
// TrainingAPIInput represents the input for training a classifier
|
|
||||||
pub struct TrainingAPIInput {
|
|
||||||
pub mut:
|
|
||||||
model string @[required]
|
|
||||||
input []TrainingExample @[required]
|
|
||||||
access string // Optional: "public" or "private"
|
|
||||||
}
|
|
||||||
|
|
||||||
// TrainingOutput represents the response from training a classifier
|
|
||||||
pub struct TrainingOutput {
|
|
||||||
pub mut:
|
|
||||||
classifier_id string
|
|
||||||
model string
|
|
||||||
status string
|
|
||||||
object string
|
|
||||||
}
|
|
||||||
|
|
||||||
// BulkEmbeddingJobResponse represents the response from bulk embedding operations
|
// BulkEmbeddingJobResponse represents the response from bulk embedding operations
|
||||||
pub struct BulkEmbeddingJobResponse {
|
pub struct BulkEmbeddingJobResponse {
|
||||||
pub mut:
|
pub mut:
|
||||||
@@ -148,16 +124,6 @@ pub fn parse_classification_output(json_str string) !ClassificationOutput {
|
|||||||
return json.decode(ClassificationOutput, json_str)
|
return json.decode(ClassificationOutput, json_str)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Serialize TrainingAPIInput to JSON
|
|
||||||
pub fn (input TrainingAPIInput) to_json() string {
|
|
||||||
return json.encode(input)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse JSON to TrainingOutput
|
|
||||||
pub fn parse_training_output(json_str string) !TrainingOutput {
|
|
||||||
return json.decode(TrainingOutput, json_str)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse JSON to BulkEmbeddingJobResponse
|
// Parse JSON to BulkEmbeddingJobResponse
|
||||||
pub fn parse_bulk_embedding_job_response(json_str string) !BulkEmbeddingJobResponse {
|
pub fn parse_bulk_embedding_job_response(json_str string) !BulkEmbeddingJobResponse {
|
||||||
return json.decode(BulkEmbeddingJobResponse, json_str)
|
return json.decode(BulkEmbeddingJobResponse, json_str)
|
||||||
|
|||||||
Reference in New Issue
Block a user