diff --git a/examples/clients/jina.vsh b/examples/clients/jina.vsh index 51b4d065..4106f53b 100755 --- a/examples/clients/jina.vsh +++ b/examples/clients/jina.vsh @@ -40,12 +40,18 @@ train_result := jina_client.train( println('Train result: ${train_result}') -// // Classify -// classification_result := jina_client.classify( -// model: .reranker_v2_base_multilingual -// query: 'skincare products' -// documents: ['Product A', 'Product B', 'Product C'] -// top_n: 2 -// ) or { panic('Error while classifying: ${err}') } +// Classify +classify_result := jina_client.classify( + model: .jina_clip_v1 + input: [ + jina.ClassificationInput{ + text: 'A photo of a cat' + }, + jina.ClassificationInput{ + image: 'https://letsenhance.io/static/73136da51c245e80edc6ccfe44888a99/1015f/MainBefore.jpg' + }, + ] + labels: ['cat', 'dog'] +) or { panic('Error while classifying: ${err}') } -// println('Classification result: ${classification_result}') +println('Classification result: ${classify_result}') diff --git a/lib/clients/jina/classification_api.v b/lib/clients/jina/classification_api.v index 4e139cb1..23af932a 100644 --- a/lib/clients/jina/classification_api.v +++ b/lib/clients/jina/classification_api.v @@ -130,3 +130,133 @@ pub fn (mut j Jina) train(params ClassificationTrain) !ClassificationTrainOutput result := json.decode(ClassificationTrainOutput, response)! return result } + +// TextDoc represents a text document for classification +pub struct TextDoc { +pub mut: + text string // The text content +} + +// ImageDoc represents an image document for classification +pub struct ImageDoc { +pub mut: + image string // The image URL or base64-encoded string +} + +// ClassificationInput represents a single input for classification (text or image) +pub struct ClassificationInput { +pub mut: + text ?string // Optional text content + image ?string // Optional image content +} + +// ClassificationOutput represents the response from the classify endpoint +pub struct ClassificationOutput { +pub mut: + data []ClassificationResult // List of classification results + usage ClassificationUsage // Token usage details +} + +// ClassificationResult represents a single classification result +pub struct ClassificationResult { +pub mut: + index int // Index of the input + prediction string // Predicted label + score f64 // Confidence score + object string // Type of object (e.g., "classification") + predictions []LabelScore // List of label scores +} + +// LabelScore represents a label and its corresponding score +pub struct LabelScore { +pub mut: + label string // Label name + score f64 // Confidence score +} + +// ClassificationUsage represents token usage for the classification request +pub struct ClassificationUsage { +pub mut: + total_tokens int // Total tokens consumed +} + +// ClassifyRequest represents the JSON request body for the /v1/classify endpoint +struct ClassifyRequest { +mut: + model ?string + classifier_id ?string + input []ClassificationInput + labels []string +} + +// ClassifyParams represents parameters for the classification request +@[params] +pub struct ClassifyParams { +pub mut: + model ?JinaModel // Optional model identifier + classifier_id ?string // Optional classifier ID + input []ClassificationInput // Array of inputs (text or image) + labels []string // List of labels for classification +} + +// Classify inputs by sending a POST request to /v1/classify +pub fn (mut j Jina) classify(params ClassifyParams) !ClassificationOutput { + // Validate that only one of model or classifier_id is provided + mut model_provided := false + mut classifier_id_provided := false + if _ := params.model { + model_provided = true + } + if _ := params.classifier_id { + classifier_id_provided = true + } + if model_provided && classifier_id_provided { + return error('Provide either model or classifier_id, not both') + } + if !model_provided && !classifier_id_provided { + return error('Either model or classifier_id must be provided') + } + + // Validate each input has exactly one of text or image + for input in params.input { + mut text_provided := false + mut image_provided := false + if _ := input.text { + text_provided = true + } + if _ := input.image { + image_provided = true + } + if text_provided && image_provided { + return error('Each input must have either text or image, not both') + } + if !text_provided && !image_provided { + return error('Each input must have either text or image') + } + } + + // Construct the request body + mut request := ClassifyRequest{ + input: params.input + labels: params.labels + } + if v := params.model { + request.model = v.to_string() // Convert JinaModel enum to string + } + if v := params.classifier_id { + request.classifier_id = v + } + + // Create and send the HTTP request + req := httpconnection.Request{ + method: .post + prefix: 'v1/classify' + dataformat: .json + data: json.encode(request) + } + + mut httpclient := j.httpclient()! + response := httpclient.post_json_str(req)! + result := json.decode(ClassificationOutput, response)! + return result +} diff --git a/lib/clients/jina/jina_client.v b/lib/clients/jina/jina_client.v index 92449cbe..5ef4c919 100644 --- a/lib/clients/jina/jina_client.v +++ b/lib/clients/jina/jina_client.v @@ -85,14 +85,6 @@ pub fn (mut j Jina) rerank(params RerankParams) !RankingOutput { return json.decode(RankingOutput, response)! } -@[params] -pub struct ClassifyParams { -pub mut: - model string @[required] // The classification model - input []string @[required] // Input texts or image URLs - labels []string @[required] // Classification labels -} - // // Create embeddings with a TextDoc input // pub fn (mut j Jina) create_embeddings_with_docs(args TextEmbeddingInput) !ModelEmbeddingOutput { diff --git a/lib/clients/jina/jina_client_test.v b/lib/clients/jina/jina_client_test.v index d5dc5e52..575de8be 100644 --- a/lib/clients/jina/jina_client_test.v +++ b/lib/clients/jina/jina_client_test.v @@ -55,3 +55,26 @@ fn test_train() { assert train_result.classifier_id.len > 0 assert train_result.num_samples == 2 } + +fn test_classify() { + time.sleep(1 * time.second) + mut client := setup_client()! + classify_result := client.classify( + model: .jina_clip_v1 + input: [ + ClassificationInput{ + text: 'A photo of a cat' + }, + ClassificationInput{ + image: 'https://letsenhance.io/static/73136da51c245e80edc6ccfe44888a99/1015f/MainBefore.jpg' + }, + ] + labels: ['cat', 'dog'] + ) or { panic('Error while classifying: ${err}') } + + assert classify_result.data.len == 2 + assert classify_result.data[0].prediction in ['cat', 'dog'] + assert classify_result.data[1].prediction in ['cat', 'dog'] + assert classify_result.data[0].object == 'classification' + assert classify_result.data[1].object == 'classification' +} diff --git a/lib/clients/jina/model_embed.v b/lib/clients/jina/model_embed.v index 159296b0..f7b564e4 100644 --- a/lib/clients/jina/model_embed.v +++ b/lib/clients/jina/model_embed.v @@ -207,13 +207,6 @@ pub fn (t TextEmbeddingInput) dumps() !string { // return text_embedding_input_from_raw(raw) // } -// TextDoc represents a document with ID and text for embedding -pub struct TextDoc { -pub mut: - id string - text string -} - // ModelEmbeddingOutput represents the response from embedding requests pub struct ModelEmbeddingOutput { pub mut: diff --git a/lib/clients/jina/model_rank.v b/lib/clients/jina/model_rank.v index e5dc145a..0dc0e43a 100644 --- a/lib/clients/jina/model_rank.v +++ b/lib/clients/jina/model_rank.v @@ -2,37 +2,6 @@ module jina import json -// ClassificationAPIInput represents the input for classification requests -pub struct ClassificationAPIInput { -pub mut: - model string @[required] - input []string @[required] - labels []string @[required] -} - -// ClassificationOutput represents the response from classification requests -pub struct ClassificationOutput { -pub mut: - model string - data []ClassificationData - usage Usage - object string -} - -// ClassificationData represents a single classification result -pub struct ClassificationData { -pub mut: - classifications []Classification - index int -} - -// Classification represents a single label classification with score -pub struct Classification { -pub mut: - label string - score f64 -} - // BulkEmbeddingJobResponse represents the response from bulk embedding operations pub struct BulkEmbeddingJobResponse { pub mut: @@ -114,16 +83,6 @@ pub fn parse_ranking_output(json_str string) !RankingOutput { return json.decode(RankingOutput, json_str) } -// Serialize ClassificationAPIInput to JSON -pub fn (input ClassificationAPIInput) to_json() string { - return json.encode(input) -} - -// Parse JSON to ClassificationOutput -pub fn parse_classification_output(json_str string) !ClassificationOutput { - return json.decode(ClassificationOutput, json_str) -} - // Parse JSON to BulkEmbeddingJobResponse pub fn parse_bulk_embedding_job_response(json_str string) !BulkEmbeddingJobResponse { return json.decode(BulkEmbeddingJobResponse, json_str)