-
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
148 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
#[cfg(test)] | ||
#[path = "./anthropic_llm_factory_tests.rs"] | ||
mod anthropic_llm_factory_tests; | ||
|
||
use std::{error::Error, sync::Arc}; | ||
|
||
use anyhow::anyhow; | ||
use langchain_rust::{ | ||
language_models::{llm::LLM, options::CallOptions}, | ||
llm::Claude, | ||
}; | ||
|
||
use crate::llm::domain::api_key_service::ApiKeyService; | ||
pub use crate::llm::domain::llm_factory::{CreateLLMParameters, LLMFactory}; | ||
|
||
pub struct AnthropicLLMFactory { | ||
api_key_service: Arc<dyn ApiKeyService>, | ||
} | ||
|
||
impl AnthropicLLMFactory { | ||
pub fn new(api_key_service: Arc<dyn ApiKeyService>) -> Self { | ||
Self { api_key_service } | ||
} | ||
} | ||
|
||
#[async_trait::async_trait] | ||
impl LLMFactory for AnthropicLLMFactory { | ||
fn is_compatible(&self, model: &str) -> bool { | ||
model.to_lowercase().contains("anthropic") | ||
} | ||
|
||
async fn create( | ||
&self, | ||
parameters: &CreateLLMParameters, | ||
) -> Result<Box<dyn LLM>, Box<dyn Error + Send>> { | ||
let split_vec: Vec<&str> = parameters.model.split(':').collect(); | ||
|
||
let model = match split_vec.last() { | ||
Some(&model) => model, | ||
None => return Err(anyhow!("Invalid model format").into()), | ||
}; | ||
let api_key = self | ||
.api_key_service | ||
.get_api_key("ANTHROPIC_API_KEY") | ||
.await?; | ||
|
||
let llm = Claude::default() | ||
.with_model(model) | ||
.with_api_key(api_key) | ||
.with_options(CallOptions { | ||
temperature: parameters.temperature, | ||
..CallOptions::default() | ||
}); | ||
Ok(Box::new(llm)) | ||
} | ||
} |
80 changes: 80 additions & 0 deletions
80
core/src/llm/infrastructure/anthropic_llm_factory_tests.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
#[cfg(test)] | ||
mod tests { | ||
use std::sync::Arc; | ||
|
||
use mockall::predicate::eq; | ||
|
||
pub use crate::llm::domain::llm_factory::CreateLLMParameters; | ||
use crate::llm::{ | ||
domain::{api_key_service::MockApiKeyService, llm_factory::LLMFactory}, | ||
infrastructure::anthropic_llm_factory::AnthropicLLMFactory, | ||
}; | ||
|
||
#[test] | ||
fn test_model_containing_openai_is_compatible() { | ||
let api_key_service = Arc::new(MockApiKeyService::new()); | ||
let instance = AnthropicLLMFactory::new(api_key_service); | ||
|
||
assert!(instance.is_compatible("anthropic")); | ||
} | ||
|
||
#[test] | ||
fn test_model_containing_openai_in_uppercase_is_compatible() { | ||
let api_key_service = Arc::new(MockApiKeyService::new()); | ||
let instance = AnthropicLLMFactory::new(api_key_service); | ||
|
||
assert!(instance.is_compatible("ANTHROPIC")); | ||
} | ||
|
||
#[test] | ||
fn test_model_containing_openai_and_model_is_compatible() { | ||
let api_key_service = Arc::new(MockApiKeyService::new()); | ||
let instance = AnthropicLLMFactory::new(api_key_service); | ||
|
||
assert!(instance.is_compatible("anthropic:claude")); | ||
} | ||
|
||
#[test] | ||
fn test_model_not_containing_openai_is_not_compatible() { | ||
let api_key_service = Arc::new(MockApiKeyService::new()); | ||
let instance = AnthropicLLMFactory::new(api_key_service); | ||
|
||
assert!(!instance.is_compatible("openai")); | ||
} | ||
|
||
#[tokio::test] | ||
async fn test_create_with_semicolon_is_created() { | ||
let mut api_key_service = MockApiKeyService::new(); | ||
api_key_service | ||
.expect_get_api_key() | ||
.with(eq("ANTHROPIC_API_KEY")) | ||
.returning(|_| Ok("ABC".to_string())); | ||
let instance = AnthropicLLMFactory::new(Arc::new(api_key_service)); | ||
let result = instance | ||
.create(&CreateLLMParameters { | ||
model: "anthropic:gpt-4o".to_string(), | ||
..CreateLLMParameters::default() | ||
}) | ||
.await; | ||
|
||
assert!(result.is_ok()); | ||
} | ||
|
||
#[tokio::test] | ||
async fn test_create_with_model_is_created() { | ||
let mut api_key_service = MockApiKeyService::new(); | ||
api_key_service | ||
.expect_get_api_key() | ||
.with(eq("ANTHROPIC_API_KEY")) | ||
.returning(|_| Ok("ABC".to_string())); | ||
let instance = AnthropicLLMFactory::new(Arc::new(api_key_service)); | ||
let result = instance | ||
.create(&CreateLLMParameters { | ||
model: "gpt-4o".to_string(), | ||
..CreateLLMParameters::default() | ||
}) | ||
.await; | ||
|
||
assert!(result.is_ok()); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
pub mod anthropic_llm_factory; | ||
pub mod llm_factory_router; | ||
pub mod openai_llm_factory; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters