Skip to content

Commit

Permalink
feat: support anthropic (#66)
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelint authored Oct 1, 2024
1 parent c7a2838 commit 6f337d5
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 3 deletions.
56 changes: 56 additions & 0 deletions core/src/llm/infrastructure/anthropic_llm_factory.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
#[cfg(test)]
#[path = "./anthropic_llm_factory_tests.rs"]
mod anthropic_llm_factory_tests;

use std::{error::Error, sync::Arc};

use anyhow::anyhow;
use langchain_rust::{
language_models::{llm::LLM, options::CallOptions},
llm::Claude,
};

use crate::llm::domain::api_key_service::ApiKeyService;
pub use crate::llm::domain::llm_factory::{CreateLLMParameters, LLMFactory};

/// Factory that builds Anthropic (Claude) LLM instances for model
/// identifiers containing "anthropic" (see `is_compatible` below).
pub struct AnthropicLLMFactory {
    /// Used to look up `ANTHROPIC_API_KEY` when a client is created.
    api_key_service: Arc<dyn ApiKeyService>,
}

impl AnthropicLLMFactory {
    /// Creates a factory backed by the given API-key lookup service.
    pub fn new(api_key_service: Arc<dyn ApiKeyService>) -> Self {
        Self { api_key_service }
    }
}

#[async_trait::async_trait]
impl LLMFactory for AnthropicLLMFactory {
    /// A model id targets this factory when it mentions "anthropic"
    /// (case-insensitive), e.g. "anthropic:claude-3-5-sonnet".
    fn is_compatible(&self, model: &str) -> bool {
        model.to_lowercase().contains("anthropic")
    }

    /// Builds a Claude client for `parameters.model`.
    ///
    /// The model string may carry a provider prefix
    /// (`"anthropic:claude-..."`); only the segment after the last `:`
    /// is forwarded to the client. A plain model name is used as-is.
    ///
    /// # Errors
    /// Fails when the `ANTHROPIC_API_KEY` lookup fails.
    async fn create(
        &self,
        parameters: &CreateLLMParameters,
    ) -> Result<Box<dyn LLM>, Box<dyn Error + Send>> {
        // `rsplit` always yields at least one segment, so this cannot
        // fail: "anthropic:claude" -> "claude", "claude" -> "claude".
        // (The previous Vec + match carried an unreachable error branch
        // and an intermediate allocation.)
        let model = parameters
            .model
            .rsplit(':')
            .next()
            .unwrap_or(&parameters.model);

        let api_key = self
            .api_key_service
            .get_api_key("ANTHROPIC_API_KEY")
            .await?;

        let llm = Claude::default()
            .with_model(model)
            .with_api_key(api_key)
            .with_options(CallOptions {
                temperature: parameters.temperature,
                ..CallOptions::default()
            });
        Ok(Box::new(llm))
    }
}
80 changes: 80 additions & 0 deletions core/src/llm/infrastructure/anthropic_llm_factory_tests.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use mockall::predicate::eq;

    pub use crate::llm::domain::llm_factory::CreateLLMParameters;
    use crate::llm::{
        domain::{api_key_service::MockApiKeyService, llm_factory::LLMFactory},
        infrastructure::anthropic_llm_factory::AnthropicLLMFactory,
    };

    // NOTE: test names previously said "openai" (copy-paste from the
    // OpenAI factory tests) while actually exercising "anthropic";
    // they are renamed here to describe what they assert.

    #[test]
    fn test_model_containing_anthropic_is_compatible() {
        let api_key_service = Arc::new(MockApiKeyService::new());
        let instance = AnthropicLLMFactory::new(api_key_service);

        assert!(instance.is_compatible("anthropic"));
    }

    #[test]
    fn test_model_containing_anthropic_in_uppercase_is_compatible() {
        let api_key_service = Arc::new(MockApiKeyService::new());
        let instance = AnthropicLLMFactory::new(api_key_service);

        assert!(instance.is_compatible("ANTHROPIC"));
    }

    #[test]
    fn test_model_containing_anthropic_and_model_is_compatible() {
        let api_key_service = Arc::new(MockApiKeyService::new());
        let instance = AnthropicLLMFactory::new(api_key_service);

        assert!(instance.is_compatible("anthropic:claude"));
    }

    #[test]
    fn test_model_not_containing_anthropic_is_not_compatible() {
        let api_key_service = Arc::new(MockApiKeyService::new());
        let instance = AnthropicLLMFactory::new(api_key_service);

        assert!(!instance.is_compatible("openai"));
    }

    // A "provider:model" id is accepted; only the trailing segment is
    // used as the model name.
    #[tokio::test]
    async fn test_create_with_colon_is_created() {
        let mut api_key_service = MockApiKeyService::new();
        api_key_service
            .expect_get_api_key()
            .with(eq("ANTHROPIC_API_KEY"))
            .returning(|_| Ok("ABC".to_string()));
        let instance = AnthropicLLMFactory::new(Arc::new(api_key_service));
        let result = instance
            .create(&CreateLLMParameters {
                model: "anthropic:gpt-4o".to_string(),
                ..CreateLLMParameters::default()
            })
            .await;

        assert!(result.is_ok());
    }

    // A bare model name (no provider prefix) also works.
    #[tokio::test]
    async fn test_create_with_model_is_created() {
        let mut api_key_service = MockApiKeyService::new();
        api_key_service
            .expect_get_api_key()
            .with(eq("ANTHROPIC_API_KEY"))
            .returning(|_| Ok("ABC".to_string()));
        let instance = AnthropicLLMFactory::new(Arc::new(api_key_service));
        let result = instance
            .create(&CreateLLMParameters {
                model: "gpt-4o".to_string(),
                ..CreateLLMParameters::default()
            })
            .await;

        assert!(result.is_ok());
    }
}
1 change: 1 addition & 0 deletions core/src/llm/infrastructure/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
pub mod anthropic_llm_factory;
pub mod llm_factory_router;
pub mod openai_llm_factory;
12 changes: 10 additions & 2 deletions core/src/llm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ use domain::{
api_key_service::{ApiKeyService, ApiKeyServiceImpl},
llm_factory::LLMFactory,
};
use infrastructure::{llm_factory_router::LLMFactoryRouter, openai_llm_factory::OpenAILLMFactory};
use infrastructure::{
anthropic_llm_factory::AnthropicLLMFactory, llm_factory_router::LLMFactoryRouter,
openai_llm_factory::OpenAILLMFactory,
};

use crate::configuration::ConfigurationDIModule;

Expand Down Expand Up @@ -34,7 +37,12 @@ impl LLMDIModule {
let api_key_service = self.get_api_key_service();
let openai_llm_factory: Arc<dyn LLMFactory> =
Arc::new(OpenAILLMFactory::new(Arc::clone(&api_key_service)));
let anthropic_llm_factory: Arc<dyn LLMFactory> =
Arc::new(AnthropicLLMFactory::new(Arc::clone(&api_key_service)));

Arc::new(LLMFactoryRouter::new(vec![openai_llm_factory]))
Arc::new(LLMFactoryRouter::new(vec![
openai_llm_factory,
anthropic_llm_factory,
]))
}
}
2 changes: 1 addition & 1 deletion webapp/src/app.config.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
const api_url = 'http://localhost:1234';
const openai_api_url = `${api_url}/openai/v1`;

const AVAILABLE_LLM_MODELS: readonly [string, ...string[]] = ['openai:gpt-4o', 'openai:gpt-4o-mini'];
const AVAILABLE_LLM_MODELS: readonly [string, ...string[]] = ['openai:gpt-4o', 'openai:gpt-4o-mini', 'anthropic:claude-3-5-sonnet-20240620', 'anthropic:claude-3-opus-20240229'];
const LLM_API_KEYS_KEYS: readonly [string, ...string[]] = ['OPENAI_API_KEY', 'ANTHROPIC_API_KEY'];

const appConfig = {
Expand Down

0 comments on commit 6f337d5

Please sign in to comment.