feat: Enhance Jina client with improved classification API

- Update `jina.vsh` example to showcase the new classification API
  with support for both text and image inputs. This improves
  the flexibility and usability of the client.
- Introduce new structs `TextDoc`, `ImageDoc`, `ClassificationInput`,
  `ClassificationOutput`, `ClassificationResult`, and `LabelScore`
  to represent data structures for classification requests and
  responses. This enhances code clarity and maintainability.
- Implement the `classify` function in `jina_client.v` to handle
  classification requests with support for text and image inputs,
  model selection, and label specification. This adds a crucial
  feature to the Jina client.
- Add comprehensive unit tests in `jina_client_test.v` to cover
  the new `classify` function's functionality. This ensures the
  correctness and robustness of the implemented feature.
- Remove redundant code related to old classification API and data
  structures from `model_embed.v`, `model_rank.v`, and
  `jina_client.v`. This streamlines the codebase and removes
  obsolete elements.
This commit is contained in:
Mahmoud Emad
2025-03-11 21:11:04 +02:00
parent 1a02dcaf0f
commit ad300c068f
6 changed files with 167 additions and 64 deletions

View File

@@ -40,12 +40,18 @@ train_result := jina_client.train(
println('Train result: ${train_result}')
// // Classify
// classification_result := jina_client.classify(
// model: .reranker_v2_base_multilingual
// query: 'skincare products'
// documents: ['Product A', 'Product B', 'Product C']
// top_n: 2
// ) or { panic('Error while classifying: ${err}') }
// Classify: one text input and one image input, scored against two labels
classify_inputs := [
	jina.ClassificationInput{
		text: 'A photo of a cat'
	},
	jina.ClassificationInput{
		image: 'https://letsenhance.io/static/73136da51c245e80edc6ccfe44888a99/1015f/MainBefore.jpg'
	},
]
classify_result := jina_client.classify(
	model:  .jina_clip_v1
	input:  classify_inputs
	labels: ['cat', 'dog']
) or { panic('Error while classifying: ${err}') }
// println('Classification result: ${classification_result}')
println('Classification result: ${classify_result}')

View File

@@ -130,3 +130,133 @@ pub fn (mut j Jina) train(params ClassificationTrain) !ClassificationTrainOutput
result := json.decode(ClassificationTrainOutput, response)!
return result
}
// TextDoc wraps a single piece of text as a classification document.
// NOTE(review): `classify` takes ClassificationInput rather than this type;
// confirm whether any caller still needs TextDoc.
pub struct TextDoc {
pub mut:
text string // The text content of the document
}
// ImageDoc wraps a single image reference as a classification document.
// NOTE(review): `classify` takes ClassificationInput rather than this type;
// confirm whether any caller still needs ImageDoc.
pub struct ImageDoc {
pub mut:
image string // The image, given as a URL or a base64-encoded string
}
// ClassificationInput is one item to classify: either a text or an image.
// Both fields are optional at the type level, but `classify` rejects an input
// that sets both of them or neither of them.
pub struct ClassificationInput {
pub mut:
text ?string // Text content; leave unset for an image input
image ?string // Image content; leave unset for a text input
}
// ClassificationOutput is the decoded response of the classify endpoint.
// `classify` produces it by JSON-decoding the response body.
pub struct ClassificationOutput {
pub mut:
data []ClassificationResult // Classification results
usage ClassificationUsage // Token usage details for the request
}
// ClassificationResult is a single entry of ClassificationOutput.data.
// NOTE(review): the relationship between `prediction`/`score` and the
// `predictions` list (e.g. top-scoring entry) is not visible here; verify
// against the API response.
pub struct ClassificationResult {
pub mut:
index int // Index of the input this result refers to
prediction string // Predicted label
score f64 // Confidence score of the predicted label
object string // Object type tag (e.g., "classification")
predictions []LabelScore // Per-label scores
}
// LabelScore pairs a label with its score; it is the element type of
// ClassificationResult.predictions.
pub struct LabelScore {
pub mut:
label string // Label name
score f64 // Confidence score for this label
}
// ClassificationUsage carries the token accounting returned with a
// classification response (see ClassificationOutput.usage).
pub struct ClassificationUsage {
pub mut:
total_tokens int // Total tokens consumed by the request
}
// ClassifyRequest is the module-private JSON request body for the
// /v1/classify endpoint. `classify` fills it from ClassifyParams and
// serializes it with json.encode.
struct ClassifyRequest {
mut:
model ?string // String form of the selected model; unset when a classifier ID is used
classifier_id ?string // Classifier ID; unset when a model is used
input []ClassificationInput // Inputs to classify
labels []string // Labels to classify against
}
// ClassifyParams holds the arguments accepted by `classify`.
// `classify` requires exactly one of `model` / `classifier_id` to be set.
@[params]
pub struct ClassifyParams {
pub mut:
model ?JinaModel // Model to use; mutually exclusive with classifier_id
classifier_id ?string // Classifier ID to use; mutually exclusive with model
input []ClassificationInput // Inputs to classify (each one text or image)
labels []string // Labels for classification
}
// classify sends the given inputs to the /v1/classify endpoint with a POST
// request and returns the decoded response.
//
// Validation performed locally, before any request is sent:
// - exactly one of `params.model` / `params.classifier_id` must be set;
// - `params.input` must contain at least one element;
// - every input must set exactly one of `text` / `image`.
//
// Returns an error when validation fails, when the HTTP client cannot be
// obtained, when the POST request fails, or when the response body cannot be
// decoded into ClassificationOutput.
pub fn (mut j Jina) classify(params ClassifyParams) !ClassificationOutput {
	// Exactly one of model / classifier_id selects the classifier
	mut model_provided := false
	mut classifier_id_provided := false
	if _ := params.model {
		model_provided = true
	}
	if _ := params.classifier_id {
		classifier_id_provided = true
	}
	if model_provided && classifier_id_provided {
		return error('Provide either model or classifier_id, not both')
	}
	if !model_provided && !classifier_id_provided {
		return error('Either model or classifier_id must be provided')
	}

	// Nothing to classify: fail locally rather than issuing a request with an
	// empty batch (the per-input loop below would otherwise be skipped silently)
	if params.input.len == 0 {
		return error('At least one input must be provided')
	}

	// Each input must carry exactly one of text or image
	for input in params.input {
		mut text_provided := false
		mut image_provided := false
		if _ := input.text {
			text_provided = true
		}
		if _ := input.image {
			image_provided = true
		}
		if text_provided && image_provided {
			return error('Each input must have either text or image, not both')
		}
		if !text_provided && !image_provided {
			return error('Each input must have either text or image')
		}
	}

	// Build the request body; only the selector that was provided gets set
	mut request := ClassifyRequest{
		input:  params.input
		labels: params.labels
	}
	if v := params.model {
		request.model = v.to_string() // Convert JinaModel enum to string
	}
	if v := params.classifier_id {
		request.classifier_id = v
	}

	// Send the JSON-encoded body and decode the reply
	req := httpconnection.Request{
		method:     .post
		prefix:     'v1/classify'
		dataformat: .json
		data:       json.encode(request)
	}
	mut httpclient := j.httpclient()!
	response := httpclient.post_json_str(req)!
	result := json.decode(ClassificationOutput, response)!
	return result
}

View File

@@ -85,14 +85,6 @@ pub fn (mut j Jina) rerank(params RerankParams) !RankingOutput {
return json.decode(RankingOutput, response)!
}
// ClassifyParams holds the arguments of a classification call.
// All three fields are marked required.
@[params]
pub struct ClassifyParams {
pub mut:
model string @[required] // The classification model
input []string @[required] // Input texts or image URLs
labels []string @[required] // Classification labels
}
// // Create embeddings with a TextDoc input
// pub fn (mut j Jina) create_embeddings_with_docs(args TextEmbeddingInput) !ModelEmbeddingOutput {

View File

@@ -55,3 +55,26 @@ fn test_train() {
assert train_result.classifier_id.len > 0
assert train_result.num_samples == 2
}
// test_classify sends one text input and one image input to the classify
// endpoint and checks the shape of the response: two results, each predicted
// as one of the candidate labels and tagged as a 'classification' object.
fn test_classify() {
	// Pause before calling the API — presumably to stay under a rate limit; confirm
	time.sleep(1 * time.second)
	mut client := setup_client()!

	candidate_labels := ['cat', 'dog']
	inputs := [
		ClassificationInput{
			text: 'A photo of a cat'
		},
		ClassificationInput{
			image: 'https://letsenhance.io/static/73136da51c245e80edc6ccfe44888a99/1015f/MainBefore.jpg'
		},
	]
	classify_result := client.classify(
		model:  .jina_clip_v1
		input:  inputs
		labels: candidate_labels
	) or { panic('Error while classifying: ${err}') }

	assert classify_result.data.len == 2
	// Every prediction must be one of the labels we asked for
	for item in classify_result.data {
		assert item.prediction in candidate_labels
	}
	// Every result must carry the expected object tag
	for item in classify_result.data {
		assert item.object == 'classification'
	}
}

View File

@@ -207,13 +207,6 @@ pub fn (t TextEmbeddingInput) dumps() !string {
// return text_embedding_input_from_raw(raw)
// }
// TextDoc is a text document identified by an ID, used as embedding input.
pub struct TextDoc {
pub mut:
id string // Document identifier
text string // Document text
}
// ModelEmbeddingOutput represents the response from embedding requests
pub struct ModelEmbeddingOutput {
pub mut:

View File

@@ -2,37 +2,6 @@ module jina
import json
// ClassificationAPIInput is the input of a classification request.
// All fields are marked required; `to_json` serializes the struct to JSON.
pub struct ClassificationAPIInput {
pub mut:
model string @[required] // Model to classify with
input []string @[required] // Inputs to classify
labels []string @[required] // Labels to classify against
}
// ClassificationOutput is the response of a classification request.
// `parse_classification_output` decodes it from a JSON string.
pub struct ClassificationOutput {
pub mut:
model string // Model name reported in the response
data []ClassificationData // Classification results
usage Usage // Usage details reported with the response
object string // Object type tag of the response
}
// ClassificationData is a single classification result: the label
// classifications together with an index.
// NOTE(review): `index` presumably refers to the position of the matching
// input; confirm against the API response.
pub struct ClassificationData {
pub mut:
classifications []Classification // Label/score pairs for this result
index int // Index of this result
}
// Classification is one label together with its score; it is the element
// type of ClassificationData.classifications.
pub struct Classification {
pub mut:
label string // Label name
score f64 // Score assigned to the label
}
// BulkEmbeddingJobResponse represents the response from bulk embedding operations
pub struct BulkEmbeddingJobResponse {
pub mut:
@@ -114,16 +83,6 @@ pub fn parse_ranking_output(json_str string) !RankingOutput {
return json.decode(RankingOutput, json_str)
}
// to_json returns the JSON encoding of this ClassificationAPIInput.
pub fn (input ClassificationAPIInput) to_json() string {
	encoded := json.encode(input)
	return encoded
}
// parse_classification_output decodes `json_str` into a ClassificationOutput.
// Returns an error when the string cannot be decoded.
pub fn parse_classification_output(json_str string) !ClassificationOutput {
	parsed := json.decode(ClassificationOutput, json_str)!
	return parsed
}
// Parse JSON to BulkEmbeddingJobResponse
pub fn parse_bulk_embedding_job_response(json_str string) !BulkEmbeddingJobResponse {
return json.decode(BulkEmbeddingJobResponse, json_str)