...
This commit is contained in:
		
							
								
								
									
										227
									
								
								packages/ai/codemonkey/src/lib.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										227
									
								
								packages/ai/codemonkey/src/lib.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,227 @@
 | 
			
		||||
use async_trait::async_trait;
 | 
			
		||||
use openrouter_rs::{OpenRouterClient, api::chat::*, types::Role, ChatCompletionResponse}; // Added ChatCompletionResponse here
 | 
			
		||||
use std::env;
 | 
			
		||||
use std::error::Error;
 | 
			
		||||
 | 
			
		||||
// Re-export Message and MessageRole for easier use in client code
 | 
			
		||||
pub use openrouter_rs::api::chat::Message;
 | 
			
		||||
pub use openrouter_rs::types::Role as MessageRole;
 | 
			
		||||
// Removed the problematic import for ChatCompletionResponse
 | 
			
		||||
// pub use openrouter_rs::api::chat::chat_completion::ChatCompletionResponse; 
 | 
			
		||||
 | 
			
		||||
#[async_trait]
 | 
			
		||||
pub trait AIProvider {
 | 
			
		||||
    async fn completion(
 | 
			
		||||
        &mut self,
 | 
			
		||||
        request: CompletionRequest,
 | 
			
		||||
    ) -> Result<ChatCompletionResponse, Box<dyn Error>>;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub struct CompletionRequest {
 | 
			
		||||
    pub model: String,
 | 
			
		||||
    pub messages: Vec<Message>,
 | 
			
		||||
    pub temperature: Option<f64>,
 | 
			
		||||
    pub max_tokens: Option<i64>,
 | 
			
		||||
    pub top_p: Option<f64>,
 | 
			
		||||
    pub stream: Option<bool>,
 | 
			
		||||
    pub stop: Option<Vec<String>>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub struct CompletionRequestBuilder<'a> {
 | 
			
		||||
    provider: &'a mut dyn AIProvider,
 | 
			
		||||
    model: String,
 | 
			
		||||
    messages: Vec<Message>,
 | 
			
		||||
    temperature: Option<f64>,
 | 
			
		||||
    max_tokens: Option<i64>,
 | 
			
		||||
    top_p: Option<f64>,
 | 
			
		||||
    stream: Option<bool>,
 | 
			
		||||
    stop: Option<Vec<String>>,
 | 
			
		||||
    provider_type: AIProviderType,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl<'a> CompletionRequestBuilder<'a> {
 | 
			
		||||
    pub fn new(provider: &'a mut dyn AIProvider, model: String, messages: Vec<Message>, provider_type: AIProviderType) -> Self {
 | 
			
		||||
        Self {
 | 
			
		||||
            provider,
 | 
			
		||||
            model,
 | 
			
		||||
            messages,
 | 
			
		||||
            temperature: None,
 | 
			
		||||
            max_tokens: None,
 | 
			
		||||
            top_p: None,
 | 
			
		||||
            stream: None,
 | 
			
		||||
            stop: None,
 | 
			
		||||
            provider_type,
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn temperature(mut self, temperature: f64) -> Self {
 | 
			
		||||
        self.temperature = Some(temperature);
 | 
			
		||||
        self
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn max_tokens(mut self, max_tokens: i64) -> Self {
 | 
			
		||||
        self.max_tokens = Some(max_tokens);
 | 
			
		||||
        self
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn top_p(mut self, top_p: f64) -> Self {
 | 
			
		||||
        self.top_p = Some(top_p);
 | 
			
		||||
        self
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn stream(mut self, stream: bool) -> Self {
 | 
			
		||||
        self.stream = Some(stream);
 | 
			
		||||
        self
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn stop(mut self, stop: Vec<String>) -> Self {
 | 
			
		||||
        self.stop = Some(stop);
 | 
			
		||||
        self
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn completion(self) -> Result<ChatCompletionResponse, Box<dyn Error>> {
 | 
			
		||||
        let request = CompletionRequest {
 | 
			
		||||
            model: self.model,
 | 
			
		||||
            messages: self.messages,
 | 
			
		||||
            temperature: self.temperature,
 | 
			
		||||
            max_tokens: self.max_tokens,
 | 
			
		||||
            top_p: self.top_p,
 | 
			
		||||
            stream: self.stream,
 | 
			
		||||
            stop: self.stop,
 | 
			
		||||
        };
 | 
			
		||||
        self.provider.completion(request).await
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub struct GroqAIProvider {
 | 
			
		||||
    client: OpenRouterClient,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[async_trait]
 | 
			
		||||
impl AIProvider for GroqAIProvider {
 | 
			
		||||
    async fn completion(
 | 
			
		||||
        &mut self,
 | 
			
		||||
        request: CompletionRequest,
 | 
			
		||||
    ) -> Result<ChatCompletionResponse, Box<dyn Error>> {
 | 
			
		||||
        let chat_request = ChatCompletionRequest::builder()
 | 
			
		||||
            .model(request.model)
 | 
			
		||||
            .messages(request.messages)
 | 
			
		||||
            .temperature(request.temperature.unwrap_or(1.0))
 | 
			
		||||
            .max_tokens(request.max_tokens.map(|x| x as u32).unwrap_or(2048))
 | 
			
		||||
            .top_p(request.top_p.unwrap_or(1.0))
 | 
			
		||||
            .stream(request.stream.unwrap_or(false)) // Corrected to field assignment
 | 
			
		||||
            .stop(request.stop.unwrap_or_default())
 | 
			
		||||
            .build()?;
 | 
			
		||||
 | 
			
		||||
        let result = self.client.send_chat_completion(&chat_request).await?;
 | 
			
		||||
        Ok(result)
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub struct OpenAIProvider {
 | 
			
		||||
    client: OpenRouterClient,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[async_trait]
 | 
			
		||||
impl AIProvider for OpenAIProvider {
 | 
			
		||||
    async fn completion(
 | 
			
		||||
        &mut self,
 | 
			
		||||
        request: CompletionRequest,
 | 
			
		||||
    ) -> Result<ChatCompletionResponse, Box<dyn Error>> {
 | 
			
		||||
        let chat_request = ChatCompletionRequest::builder()
 | 
			
		||||
            .model(request.model)
 | 
			
		||||
            .messages(request.messages)
 | 
			
		||||
            .temperature(request.temperature.unwrap_or(1.0))
 | 
			
		||||
            .max_tokens(request.max_tokens.map(|x| x as u32).unwrap_or(2048))
 | 
			
		||||
            .top_p(request.top_p.unwrap_or(1.0))
 | 
			
		||||
            .stream(request.stream.unwrap_or(false)) // Corrected to field assignment
 | 
			
		||||
            .stop(request.stop.unwrap_or_default())
 | 
			
		||||
            .build()?;
 | 
			
		||||
 | 
			
		||||
        let result = self.client.send_chat_completion(&chat_request).await?;
 | 
			
		||||
        Ok(result)
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub struct OpenRouterAIProvider {
 | 
			
		||||
    client: OpenRouterClient,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[async_trait]
 | 
			
		||||
impl AIProvider for OpenRouterAIProvider {
 | 
			
		||||
    async fn completion(
 | 
			
		||||
        &mut self,
 | 
			
		||||
        request: CompletionRequest,
 | 
			
		||||
    ) -> Result<ChatCompletionResponse, Box<dyn Error>> {
 | 
			
		||||
        let chat_request = ChatCompletionRequest::builder()
 | 
			
		||||
            .model(request.model)
 | 
			
		||||
            .messages(request.messages)
 | 
			
		||||
            .temperature(request.temperature.unwrap_or(1.0))
 | 
			
		||||
            .max_tokens(request.max_tokens.map(|x| x as u32).unwrap_or(2048))
 | 
			
		||||
            .top_p(request.top_p.unwrap_or(1.0))
 | 
			
		||||
            .stream(request.stream.unwrap_or(false)) // Corrected to field assignment
 | 
			
		||||
            .stop(request.stop.unwrap_or_default())
 | 
			
		||||
            .build()?;
 | 
			
		||||
 | 
			
		||||
        let result = self.client.send_chat_completion(&chat_request).await?;
 | 
			
		||||
        Ok(result)
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub struct CerebrasAIProvider {
 | 
			
		||||
    client: OpenRouterClient,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[async_trait]
 | 
			
		||||
impl AIProvider for CerebrasAIProvider {
 | 
			
		||||
    async fn completion(
 | 
			
		||||
        &mut self,
 | 
			
		||||
        request: CompletionRequest,
 | 
			
		||||
    ) -> Result<ChatCompletionResponse, Box<dyn Error>> {
 | 
			
		||||
        let chat_request = ChatCompletionRequest::builder()
 | 
			
		||||
            .model(request.model)
 | 
			
		||||
            .messages(request.messages)
 | 
			
		||||
            .temperature(request.temperature.unwrap_or(1.0))
 | 
			
		||||
            .max_tokens(request.max_tokens.map(|x| x as u32).unwrap_or(2048))
 | 
			
		||||
            .top_p(request.top_p.unwrap_or(1.0))
 | 
			
		||||
            .stream(request.stream.unwrap_or(false)) // Corrected to field assignment
 | 
			
		||||
            .stop(request.stop.unwrap_or_default())
 | 
			
		||||
            .build()?;
 | 
			
		||||
 | 
			
		||||
        let result = self.client.send_chat_completion(&chat_request).await?;
 | 
			
		||||
        Ok(result)
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(PartialEq)]
 | 
			
		||||
pub enum AIProviderType {
 | 
			
		||||
    Groq,
 | 
			
		||||
    OpenAI,
 | 
			
		||||
    OpenRouter,
 | 
			
		||||
    Cerebras,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub fn create_ai_provider(provider_type: AIProviderType) -> Result<(Box<dyn AIProvider>, AIProviderType), Box<dyn Error>> {
 | 
			
		||||
    match provider_type {
 | 
			
		||||
        AIProviderType::Groq => {
 | 
			
		||||
            let api_key = env::var("GROQ_API_KEY")?;
 | 
			
		||||
            let client = OpenRouterClient::builder().api_key(api_key).build()?;
 | 
			
		||||
            Ok((Box::new(GroqAIProvider { client }), AIProviderType::Groq))
 | 
			
		||||
        }
 | 
			
		||||
        AIProviderType::OpenAI => {
 | 
			
		||||
            let api_key = env::var("OPENAI_API_KEY")?;
 | 
			
		||||
            let client = OpenRouterClient::builder().api_key(api_key).build()?;
 | 
			
		||||
            Ok((Box::new(OpenAIProvider { client }), AIProviderType::OpenAI))
 | 
			
		||||
        }
 | 
			
		||||
        AIProviderType::OpenRouter => {
 | 
			
		||||
            let api_key = env::var("OPENROUTER_API_KEY")?;
 | 
			
		||||
            let client = OpenRouterClient::builder().api_key(api_key).build()?;
 | 
			
		||||
            Ok((Box::new(OpenRouterAIProvider { client }), AIProviderType::OpenRouter))
 | 
			
		||||
        }
 | 
			
		||||
        AIProviderType::Cerebras => {
 | 
			
		||||
            let api_key = env::var("CEREBRAS_API_KEY")?;
 | 
			
		||||
            let client = OpenRouterClient::builder().api_key(api_key).build()?;
 | 
			
		||||
            Ok((Box::new(CerebrasAIProvider { client }), AIProviderType::Cerebras))
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
		Reference in New Issue
	
	Block a user