feat: Enhance Jina client with improved classification API
- Update `jina.vsh` example to showcase the new classification API with support for both text and image inputs. This improves the flexibility and usability of the client. - Introduce new structs `TextDoc`, `ImageDoc`, `ClassificationInput`, `ClassificationOutput`, `ClassificationResult`, and `LabelScore` to represent data structures for classification requests and responses. This enhances code clarity and maintainability. - Implement the `classify` function in `jina_client.v` to handle classification requests with support for text and image inputs, model selection, and label specification. This adds a crucial feature to the Jina client. - Add comprehensive unit tests in `jina_client_test.v` to cover the new `classify` function's functionality. This ensures the correctness and robustness of the implemented feature. - Remove redundant code related to old classification API and data structures from `model_embed.v`, `model_rank.v`, and `jina_client.v`. This streamlines the codebase and removes obsolete elements.
This commit is contained in:
@@ -40,12 +40,18 @@ train_result := jina_client.train(
|
||||
|
||||
println('Train result: ${train_result}')
|
||||
|
||||
// // Classify
|
||||
// classification_result := jina_client.classify(
|
||||
// model: .reranker_v2_base_multilingual
|
||||
// query: 'skincare products'
|
||||
// documents: ['Product A', 'Product B', 'Product C']
|
||||
// top_n: 2
|
||||
// ) or { panic('Error while classifying: ${err}') }
|
||||
// Classify
|
||||
classify_result := jina_client.classify(
|
||||
model: .jina_clip_v1
|
||||
input: [
|
||||
jina.ClassificationInput{
|
||||
text: 'A photo of a cat'
|
||||
},
|
||||
jina.ClassificationInput{
|
||||
image: 'https://letsenhance.io/static/73136da51c245e80edc6ccfe44888a99/1015f/MainBefore.jpg'
|
||||
},
|
||||
]
|
||||
labels: ['cat', 'dog']
|
||||
) or { panic('Error while classifying: ${err}') }
|
||||
|
||||
// println('Classification result: ${classification_result}')
|
||||
println('Classification result: ${classify_result}')
|
||||
|
||||
@@ -130,3 +130,133 @@ pub fn (mut j Jina) train(params ClassificationTrain) !ClassificationTrainOutput
|
||||
result := json.decode(ClassificationTrainOutput, response)!
|
||||
return result
|
||||
}
|
||||
|
||||
// TextDoc represents a text document for classification
|
||||
pub struct TextDoc {
|
||||
pub mut:
|
||||
text string // The text content
|
||||
}
|
||||
|
||||
// ImageDoc represents an image document for classification
|
||||
pub struct ImageDoc {
|
||||
pub mut:
|
||||
image string // The image URL or base64-encoded string
|
||||
}
|
||||
|
||||
// ClassificationInput represents a single input for classification (text or image)
|
||||
pub struct ClassificationInput {
|
||||
pub mut:
|
||||
text ?string // Optional text content
|
||||
image ?string // Optional image content
|
||||
}
|
||||
|
||||
// ClassificationOutput represents the response from the classify endpoint
|
||||
pub struct ClassificationOutput {
|
||||
pub mut:
|
||||
data []ClassificationResult // List of classification results
|
||||
usage ClassificationUsage // Token usage details
|
||||
}
|
||||
|
||||
// ClassificationResult represents a single classification result
|
||||
pub struct ClassificationResult {
|
||||
pub mut:
|
||||
index int // Index of the input
|
||||
prediction string // Predicted label
|
||||
score f64 // Confidence score
|
||||
object string // Type of object (e.g., "classification")
|
||||
predictions []LabelScore // List of label scores
|
||||
}
|
||||
|
||||
// LabelScore represents a label and its corresponding score
|
||||
pub struct LabelScore {
|
||||
pub mut:
|
||||
label string // Label name
|
||||
score f64 // Confidence score
|
||||
}
|
||||
|
||||
// ClassificationUsage represents token usage for the classification request
|
||||
pub struct ClassificationUsage {
|
||||
pub mut:
|
||||
total_tokens int // Total tokens consumed
|
||||
}
|
||||
|
||||
// ClassifyRequest represents the JSON request body for the /v1/classify endpoint
|
||||
struct ClassifyRequest {
|
||||
mut:
|
||||
model ?string
|
||||
classifier_id ?string
|
||||
input []ClassificationInput
|
||||
labels []string
|
||||
}
|
||||
|
||||
// ClassifyParams represents parameters for the classification request
|
||||
@[params]
|
||||
pub struct ClassifyParams {
|
||||
pub mut:
|
||||
model ?JinaModel // Optional model identifier
|
||||
classifier_id ?string // Optional classifier ID
|
||||
input []ClassificationInput // Array of inputs (text or image)
|
||||
labels []string // List of labels for classification
|
||||
}
|
||||
|
||||
// Classify inputs by sending a POST request to /v1/classify
|
||||
pub fn (mut j Jina) classify(params ClassifyParams) !ClassificationOutput {
|
||||
// Validate that only one of model or classifier_id is provided
|
||||
mut model_provided := false
|
||||
mut classifier_id_provided := false
|
||||
if _ := params.model {
|
||||
model_provided = true
|
||||
}
|
||||
if _ := params.classifier_id {
|
||||
classifier_id_provided = true
|
||||
}
|
||||
if model_provided && classifier_id_provided {
|
||||
return error('Provide either model or classifier_id, not both')
|
||||
}
|
||||
if !model_provided && !classifier_id_provided {
|
||||
return error('Either model or classifier_id must be provided')
|
||||
}
|
||||
|
||||
// Validate each input has exactly one of text or image
|
||||
for input in params.input {
|
||||
mut text_provided := false
|
||||
mut image_provided := false
|
||||
if _ := input.text {
|
||||
text_provided = true
|
||||
}
|
||||
if _ := input.image {
|
||||
image_provided = true
|
||||
}
|
||||
if text_provided && image_provided {
|
||||
return error('Each input must have either text or image, not both')
|
||||
}
|
||||
if !text_provided && !image_provided {
|
||||
return error('Each input must have either text or image')
|
||||
}
|
||||
}
|
||||
|
||||
// Construct the request body
|
||||
mut request := ClassifyRequest{
|
||||
input: params.input
|
||||
labels: params.labels
|
||||
}
|
||||
if v := params.model {
|
||||
request.model = v.to_string() // Convert JinaModel enum to string
|
||||
}
|
||||
if v := params.classifier_id {
|
||||
request.classifier_id = v
|
||||
}
|
||||
|
||||
// Create and send the HTTP request
|
||||
req := httpconnection.Request{
|
||||
method: .post
|
||||
prefix: 'v1/classify'
|
||||
dataformat: .json
|
||||
data: json.encode(request)
|
||||
}
|
||||
|
||||
mut httpclient := j.httpclient()!
|
||||
response := httpclient.post_json_str(req)!
|
||||
result := json.decode(ClassificationOutput, response)!
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -85,14 +85,6 @@ pub fn (mut j Jina) rerank(params RerankParams) !RankingOutput {
|
||||
return json.decode(RankingOutput, response)!
|
||||
}
|
||||
|
||||
@[params]
|
||||
pub struct ClassifyParams {
|
||||
pub mut:
|
||||
model string @[required] // The classification model
|
||||
input []string @[required] // Input texts or image URLs
|
||||
labels []string @[required] // Classification labels
|
||||
}
|
||||
|
||||
// // Create embeddings with a TextDoc input
|
||||
// pub fn (mut j Jina) create_embeddings_with_docs(args TextEmbeddingInput) !ModelEmbeddingOutput {
|
||||
|
||||
|
||||
@@ -55,3 +55,26 @@ fn test_train() {
|
||||
assert train_result.classifier_id.len > 0
|
||||
assert train_result.num_samples == 2
|
||||
}
|
||||
|
||||
fn test_classify() {
|
||||
time.sleep(1 * time.second)
|
||||
mut client := setup_client()!
|
||||
classify_result := client.classify(
|
||||
model: .jina_clip_v1
|
||||
input: [
|
||||
ClassificationInput{
|
||||
text: 'A photo of a cat'
|
||||
},
|
||||
ClassificationInput{
|
||||
image: 'https://letsenhance.io/static/73136da51c245e80edc6ccfe44888a99/1015f/MainBefore.jpg'
|
||||
},
|
||||
]
|
||||
labels: ['cat', 'dog']
|
||||
) or { panic('Error while classifying: ${err}') }
|
||||
|
||||
assert classify_result.data.len == 2
|
||||
assert classify_result.data[0].prediction in ['cat', 'dog']
|
||||
assert classify_result.data[1].prediction in ['cat', 'dog']
|
||||
assert classify_result.data[0].object == 'classification'
|
||||
assert classify_result.data[1].object == 'classification'
|
||||
}
|
||||
|
||||
@@ -207,13 +207,6 @@ pub fn (t TextEmbeddingInput) dumps() !string {
|
||||
// return text_embedding_input_from_raw(raw)
|
||||
// }
|
||||
|
||||
// TextDoc represents a document with ID and text for embedding
|
||||
pub struct TextDoc {
|
||||
pub mut:
|
||||
id string
|
||||
text string
|
||||
}
|
||||
|
||||
// ModelEmbeddingOutput represents the response from embedding requests
|
||||
pub struct ModelEmbeddingOutput {
|
||||
pub mut:
|
||||
|
||||
@@ -2,37 +2,6 @@ module jina
|
||||
|
||||
import json
|
||||
|
||||
// ClassificationAPIInput represents the input for classification requests
|
||||
pub struct ClassificationAPIInput {
|
||||
pub mut:
|
||||
model string @[required]
|
||||
input []string @[required]
|
||||
labels []string @[required]
|
||||
}
|
||||
|
||||
// ClassificationOutput represents the response from classification requests
|
||||
pub struct ClassificationOutput {
|
||||
pub mut:
|
||||
model string
|
||||
data []ClassificationData
|
||||
usage Usage
|
||||
object string
|
||||
}
|
||||
|
||||
// ClassificationData represents a single classification result
|
||||
pub struct ClassificationData {
|
||||
pub mut:
|
||||
classifications []Classification
|
||||
index int
|
||||
}
|
||||
|
||||
// Classification represents a single label classification with score
|
||||
pub struct Classification {
|
||||
pub mut:
|
||||
label string
|
||||
score f64
|
||||
}
|
||||
|
||||
// BulkEmbeddingJobResponse represents the response from bulk embedding operations
|
||||
pub struct BulkEmbeddingJobResponse {
|
||||
pub mut:
|
||||
@@ -114,16 +83,6 @@ pub fn parse_ranking_output(json_str string) !RankingOutput {
|
||||
return json.decode(RankingOutput, json_str)
|
||||
}
|
||||
|
||||
// Serialize ClassificationAPIInput to JSON
|
||||
pub fn (input ClassificationAPIInput) to_json() string {
|
||||
return json.encode(input)
|
||||
}
|
||||
|
||||
// Parse JSON to ClassificationOutput
|
||||
pub fn parse_classification_output(json_str string) !ClassificationOutput {
|
||||
return json.decode(ClassificationOutput, json_str)
|
||||
}
|
||||
|
||||
// Parse JSON to BulkEmbeddingJobResponse
|
||||
pub fn parse_bulk_embedding_job_response(json_str string) !BulkEmbeddingJobResponse {
|
||||
return json.decode(BulkEmbeddingJobResponse, json_str)
|
||||
|
||||
Reference in New Issue
Block a user