From 01b9f28ebb9b0ae88b8fe452e8dc73c0a9d63ba9 Mon Sep 17 00:00:00 2001 From: Samuel Date: Wed, 2 Oct 2024 16:03:53 -0400 Subject: [PATCH 1/8] feat: profile --- core/migration/src/lib.rs | 6 ++ .../m20220101_000006_create_profile_table.rs | 46 +++++++++ ...1_000007_seed_profile_software_engineer.rs | 48 ++++++++++ core/src/api_facade.rs | 11 +++ core/src/app_container.rs | 5 +- core/src/entities/mod.rs | 1 + core/src/entities/prelude.rs | 1 + core/src/entities/profile.rs | 20 ++++ core/src/lib.rs | 1 + core/src/profile/domain/dto.rs | 7 ++ core/src/profile/domain/mod.rs | 6 ++ core/src/profile/domain/profile_repository.rs | 13 +++ .../domain/selected_profile_service.rs | 25 +++++ .../profile/infrastructure/entity_adapter.rs | 22 +++++ core/src/profile/infrastructure/mod.rs | 4 + .../sea_orm_profile_repository.rs | 93 +++++++++++++++++++ core/src/profile/mod.rs | 31 +++++++ 17 files changed, 339 insertions(+), 1 deletion(-) create mode 100644 core/migration/src/m20220101_000006_create_profile_table.rs create mode 100644 core/migration/src/m20220101_000007_seed_profile_software_engineer.rs create mode 100644 core/src/entities/profile.rs create mode 100644 core/src/profile/domain/dto.rs create mode 100644 core/src/profile/domain/mod.rs create mode 100644 core/src/profile/domain/profile_repository.rs create mode 100644 core/src/profile/domain/selected_profile_service.rs create mode 100644 core/src/profile/infrastructure/entity_adapter.rs create mode 100644 core/src/profile/infrastructure/mod.rs create mode 100644 core/src/profile/infrastructure/sea_orm_profile_repository.rs create mode 100644 core/src/profile/mod.rs diff --git a/core/migration/src/lib.rs b/core/migration/src/lib.rs index 8b1edbb..224c5dd 100644 --- a/core/migration/src/lib.rs +++ b/core/migration/src/lib.rs @@ -6,6 +6,10 @@ mod m20220101_000002_create_thread_table; mod m20220101_000003_create_run_table; mod m20220101_000004_create_message_table; mod m20220101_000005_seed_default_llm_configuration; +mod m20220101_000006_create_profile_table; +mod m20220101_000007_seed_profile_software_engineer; + +pub use m20220101_000007_seed_profile_software_engineer::SOFTWARE_ENGINEER_PROFILE_NAME; pub struct Migrator; @@ -18,6 +22,8 @@ impl MigratorTrait for Migrator { Box::new(m20220101_000003_create_run_table::Migration), Box::new(m20220101_000004_create_message_table::Migration), Box::new(m20220101_000005_seed_default_llm_configuration::Migration), + Box::new(m20220101_000006_create_profile_table::Migration), + Box::new(m20220101_000007_seed_profile_software_engineer::Migration), ] } } diff --git a/core/migration/src/m20220101_000006_create_profile_table.rs b/core/migration/src/m20220101_000006_create_profile_table.rs new file mode 100644 index 0000000..8afa3e1 --- /dev/null +++ b/core/migration/src/m20220101_000006_create_profile_table.rs @@ -0,0 +1,46 @@ +use sea_orm_migration::prelude::*; + +#[derive(DeriveMigrationName)] +pub struct Migration; + +#[async_trait::async_trait] +impl MigrationTrait for Migration { + async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> { + manager + .create_table( + Table::create() + .table(Profile::Table) + .if_not_exists() + .col( + ColumnDef::new(Profile::Id) + .integer() + .not_null() + .auto_increment() + .primary_key(), + ) + .col( + ColumnDef::new(Profile::Name) + .string() + .unique_key() + .not_null(), + ) + .col(ColumnDef::new(Profile::Prompt).string().not_null()) + .to_owned(), + ) + .await + } + + async fn down(&self, manager: &SchemaManager) -> Result<(), DbErr> { + manager + 
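            // Rollback: drop the `profile` table created in `up`.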
+            .drop_table(Table::drop().table(Profile::Table).to_owned())
+            .await
+    }
+}
+
+#[derive(DeriveIden)]
+pub enum Profile {
+    Table,
+    Id,
+    Name,
+    Prompt,
+}
diff --git a/core/migration/src/m20220101_000007_seed_profile_software_engineer.rs b/core/migration/src/m20220101_000007_seed_profile_software_engineer.rs
new file mode 100644
index 0000000..fac7be5
--- /dev/null
+++ b/core/migration/src/m20220101_000007_seed_profile_software_engineer.rs
@@ -0,0 +1,48 @@
+use sea_orm_migration::prelude::*;
+
+use crate::m20220101_000006_create_profile_table::Profile;
+
+#[derive(DeriveMigrationName)]
+pub struct Migration;
+
+pub static SOFTWARE_ENGINEER_PROFILE_NAME: &str = "Senior Software Engineer";
+static SOFTWARE_ENGINEER_PROFILE_PROMPT: &str = r#"
+# audience
+A senior software developer.
+
+# style
+Be straightforward and concise. Only give explanations if asked.
+
+## References
+When the answer involves an external project, dependency, command-line tool, application, executable, library, or any other external reference: ALWAYS provide sources and give a URL to the reference. Prefer sources that show how to install and use it.
+
+## Code Format
+When asked about code, give a code example.
+Provide a library answer only if the question is explicitly about code and a language is specified.
+If an existing library (or several) already exists for the question, suggest it.
+"#;
+
+#[async_trait::async_trait]
+impl MigrationTrait for Migration {
+    async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> {
+        let insert: InsertStatement = Query::insert()
+            .into_table(Profile::Table)
+            .columns([Profile::Name, Profile::Prompt])
+            .values_panic([
+                SOFTWARE_ENGINEER_PROFILE_NAME.into(),
+                SOFTWARE_ENGINEER_PROFILE_PROMPT.into(),
+            ])
+            .to_owned();
+
+        manager.exec_stmt(insert).await
+    }
+
+    async fn down(&self, manager: &SchemaManager) -> Result<(), DbErr> {
+        let delete = Query::delete()
+            .from_table(Profile::Table)
+            .and_where(Expr::col(Profile::Name).eq(SOFTWARE_ENGINEER_PROFILE_NAME))
+            .to_owned();
+
+        manager.exec_stmt(delete).await
+    }
+}
diff --git a/core/src/api_facade.rs b/core/src/api_facade.rs
index c5da1ae..a039aea 100644
--- a/core/src/api_facade.rs
+++ b/core/src/api_facade.rs
@@ -3,6 +3,7 @@ use std::sync::Arc;
 use crate::{
     chat_completion::{ChatCompletionMessageDto, ChatCompletionResult, ChatCompletionStream},
     configuration::ConfigurationDto,
+    profile::domain::dto::ProfileDto,
     AppContainer,
 };
@@ -64,4 +65,14 @@ impl ApiFacade {
         configuration_service.upsert(key, value).await
     }
+
+    pub async fn get_selected_profiles(
+        &self,
+    ) -> Result<Vec<ProfileDto>, Box<dyn std::error::Error>> {
+        self.container
+            .profile_module
+            .get_selected_profiles_service()
+            .get()
+            .await
+    }
 }
diff --git a/core/src/app_container.rs b/core/src/app_container.rs
index 0f713ab..0a420fc 100644
--- a/core/src/app_container.rs
+++ b/core/src/app_container.rs
@@ -3,7 +3,7 @@ use std::{error::Error, sync::Arc};
 use crate::{
     app_configuration::CoreConfiguration, assistant::AgentDIModule,
     chat_completion::ChatCompletionDIModule, configuration::ConfigurationDIModule,
-    infrastructure::sea_orm::ConnectionFactory, llm::LLMDIModule,
+    infrastructure::sea_orm::ConnectionFactory, llm::LLMDIModule, profile::ProfileDIModule,
 };

 pub struct AppContainer {
@@ -12,6 +12,7 @@ pub struct AppContainer {
     pub configuration_module: Arc<ConfigurationDIModule>,
     pub llm_module: Arc<LLMDIModule>,
     pub chat_completion_module: Arc<ChatCompletionDIModule>,
+    pub profile_module: Arc<ProfileDIModule>,
     pub agent_module: AgentDIModule,
 }

@@ -23,6 +24,7 @@ impl AppContainer {
         let configuration_module =
            Arc::new(ConfigurationDIModule::new(Arc::clone(&connection)));
         let llm_module = Arc::new(LLMDIModule::new(configuration_module.clone()));
         let chat_completion_module = Arc::new(ChatCompletionDIModule::new(Arc::clone(&llm_module)));
+        let profile_module = Arc::new(ProfileDIModule::new(Arc::clone(&connection)));
         let agent_module: AgentDIModule =
             AgentDIModule::new(Arc::clone(&connection), Arc::clone(&chat_completion_module));
@@ -32,6 +34,7 @@ impl AppContainer {
             configuration_module,
             llm_module,
             chat_completion_module,
+            profile_module,
             agent_module,
         })
     }
diff --git a/core/src/entities/mod.rs b/core/src/entities/mod.rs
index 626c767..2cec971 100644
--- a/core/src/entities/mod.rs
+++ b/core/src/entities/mod.rs
@@ -4,5 +4,6 @@ pub mod prelude;

 pub mod configuration;
 pub mod message;
+pub mod profile;
 pub mod run;
 pub mod thread;
diff --git a/core/src/entities/prelude.rs b/core/src/entities/prelude.rs
index a5422e5..688a0e1 100644
--- a/core/src/entities/prelude.rs
+++ b/core/src/entities/prelude.rs
@@ -2,5 +2,6 @@

 pub use super::configuration::Entity as Configuration;
 pub use super::message::Entity as Message;
+pub use super::profile::Entity as Profile;
 pub use super::run::Entity as Run;
 pub use super::thread::Entity as Thread;
diff --git a/core/src/entities/profile.rs b/core/src/entities/profile.rs
new file mode 100644
index 0000000..930ce86
--- /dev/null
+++ b/core/src/entities/profile.rs
@@ -0,0 +1,20 @@
+//! `SeaORM` Entity, @generated by sea-orm-codegen 1.0.1
+
+use sea_orm::entity::prelude::*;
+use serde::{Deserialize, Serialize};
+
+#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq, Serialize, Deserialize)]
+#[sea_orm(table_name = "profile")]
+pub struct Model {
+    #[sea_orm(primary_key)]
+    pub id: i32,
+    #[sea_orm(column_type = "Text", unique)]
+    pub name: String,
+    #[sea_orm(column_type = "Text")]
+    pub prompt: String,
+}
+
+#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
+pub enum Relation {}
+
+impl ActiveModelBehavior for ActiveModel {}
diff --git a/core/src/lib.rs b/core/src/lib.rs
index e500b95..a8f43ee 100644
--- a/core/src/lib.rs
+++ b/core/src/lib.rs
@@ -6,6 +6,7 @@ pub mod chat_completion;
 pub mod configuration;
 mod infrastructure;
 mod llm;
+pub mod profile;
 pub mod utils;

 pub mod entities;
diff --git a/core/src/profile/domain/dto.rs b/core/src/profile/domain/dto.rs
new file mode 100644
index 0000000..8b3a40a
--- /dev/null
+++ b/core/src/profile/domain/dto.rs
@@ -0,0 +1,7 @@
+use serde::{Deserialize, Serialize};
+
+#[derive(Debug, Clone, Serialize, Deserialize, Default)]
+pub struct ProfileDto {
+    pub name: String,
+    pub prompt: String,
+}
diff --git a/core/src/profile/domain/mod.rs b/core/src/profile/domain/mod.rs
new file mode 100644
index 0000000..0977bcc
--- /dev/null
+++ b/core/src/profile/domain/mod.rs
@@ -0,0 +1,6 @@
+pub mod dto;
+mod profile_repository;
+mod selected_profile_service;
+
+pub use profile_repository::ProfileRepository;
+pub use selected_profile_service::SelectedProfileService;
diff --git a/core/src/profile/domain/profile_repository.rs b/core/src/profile/domain/profile_repository.rs
new file mode 100644
index 0000000..28ebb00
--- /dev/null
+++ b/core/src/profile/domain/profile_repository.rs
@@ -0,0 +1,13 @@
+use mockall::automock;
+use std::error::Error;
+
+use super::dto::ProfileDto;
+
+#[automock]
+#[async_trait::async_trait]
+pub trait ProfileRepository: Sync + Send {
+    async fn find(&self, id: &str) -> Result<Option<ProfileDto>, Box<dyn Error>>;
+    async fn find_by_name(&self, name: &str) -> Result<Option<ProfileDto>, Box<dyn Error>>;
+    async fn upsert(&self, model: &ProfileDto) -> Result<ProfileDto, Box<dyn Error>>;
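+    /// Deletes the profile with the given `id`; like `find`, `id` is the
+    /// integer primary key rendered as a string.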
+    async fn delete(&self, id: &str) -> Result<(), Box<dyn Error>>;
+}
diff --git a/core/src/profile/domain/selected_profile_service.rs b/core/src/profile/domain/selected_profile_service.rs
new file mode 100644
index 0000000..2606de3
--- /dev/null
+++ b/core/src/profile/domain/selected_profile_service.rs
@@ -0,0 +1,25 @@
+use std::{error::Error, sync::Arc};
+
+use migration::SOFTWARE_ENGINEER_PROFILE_NAME;
+
+use super::{dto::ProfileDto, ProfileRepository};
+
+pub struct SelectedProfileService {
+    profile_repository: Arc<dyn ProfileRepository>,
+}
+
+impl SelectedProfileService {
+    pub fn new(profile_repository: Arc<dyn ProfileRepository>) -> Self {
+        Self { profile_repository }
+    }
+
+    pub async fn get(&self) -> Result<Vec<ProfileDto>, Box<dyn Error>> {
+        let profile = self
+            .profile_repository
+            .find_by_name(SOFTWARE_ENGINEER_PROFILE_NAME)
+            .await?
+            .ok_or("selected profile not found; run the seed migrations first")?;
+
+        Ok(vec![profile])
+    }
+}
diff --git a/core/src/profile/infrastructure/entity_adapter.rs b/core/src/profile/infrastructure/entity_adapter.rs
new file mode 100644
index 0000000..c72c23a
--- /dev/null
+++ b/core/src/profile/infrastructure/entity_adapter.rs
@@ -0,0 +1,22 @@
+use sea_orm::Set;
+
+use crate::{entities::profile, profile::domain::dto::ProfileDto};
+
+impl From<&profile::Model> for ProfileDto {
+    fn from(model: &profile::Model) -> Self {
+        Self {
+            name: model.name.clone(),
+            prompt: model.prompt.clone(),
+        }
+    }
+}
+
+impl From<&ProfileDto> for profile::ActiveModel {
+    fn from(model: &ProfileDto) -> Self {
+        Self {
+            name: Set(model.name.clone()),
+            prompt: Set(model.prompt.clone()),
+            ..Default::default()
+        }
+    }
+}
diff --git a/core/src/profile/infrastructure/mod.rs b/core/src/profile/infrastructure/mod.rs
new file mode 100644
index 0000000..a8023ef
--- /dev/null
+++ b/core/src/profile/infrastructure/mod.rs
@@ -0,0 +1,4 @@
+mod entity_adapter;
+mod sea_orm_profile_repository;
+
+pub use sea_orm_profile_repository::SeaOrmProfileRepository;
diff --git a/core/src/profile/infrastructure/sea_orm_profile_repository.rs b/core/src/profile/infrastructure/sea_orm_profile_repository.rs
new file mode 100644
index 0000000..e2dcd77
--- /dev/null
+++ b/core/src/profile/infrastructure/sea_orm_profile_repository.rs
@@ -0,0 +1,93 @@
+use crate::entities::profile;
+use crate::profile::domain::dto::ProfileDto;
+use crate::profile::domain::ProfileRepository;
+use anyhow::anyhow;
+use sea_orm::sea_query::OnConflict;
+use sea_orm::{ColumnTrait, DatabaseConnection, EntityTrait, QueryFilter};
+use std::error::Error;
+use std::num::ParseIntError;
+use std::sync::Arc;
+
+pub struct SeaOrmProfileRepository {
+    connection: Arc<DatabaseConnection>,
+}
+
+impl SeaOrmProfileRepository {
+    pub fn new(connection: Arc<DatabaseConnection>) -> Self {
+        Self { connection }
+    }
+}
+
+#[async_trait::async_trait]
+impl ProfileRepository for SeaOrmProfileRepository {
+    async fn find(&self, id: &str) -> Result<Option<ProfileDto>, Box<dyn Error>> {
+        let conn = Arc::clone(&self.connection);
+        let id: i32 = id.parse().map_err(|e: ParseIntError| anyhow!(e))?;
+
+        let profile = profile::Entity::find_by_id(id)
+            .one(conn.as_ref())
+            .await
+            .map_err(|e| anyhow!(e))?;
+
+        Ok(profile.as_ref().map(ProfileDto::from))
+    }
+
+    async fn find_by_name(&self, name: &str) -> Result<Option<ProfileDto>, Box<dyn Error>> {
+        let conn = Arc::clone(&self.connection);
+        let profile = profile::Entity::find()
+            .filter(profile::Column::Name.eq(name))
+            .one(conn.as_ref())
+            .await
+            .map_err(|e| anyhow!(e))?;
+
+        Ok(profile.as_ref().map(ProfileDto::from))
+    }
+
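+    /// Inserts the profile, or updates the prompt of the existing row with
+    /// the same unique name, then reads back and returns the stored profile.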
+    async fn upsert(&self, dto: &ProfileDto) -> Result<ProfileDto, Box<dyn Error>> {
+        let conn = Arc::clone(&self.connection);
+        let model: profile::ActiveModel = dto.into();
+
+        let on_conflict = OnConflict::column(profile::Column::Name)
+            .update_column(profile::Column::Prompt)
+            .to_owned();
+
+        profile::Entity::insert(model)
+            .on_conflict(on_conflict)
+            .exec(conn.as_ref())
+            .await
+            .map_err(|e| anyhow!(e))?;
+
+        let result = profile::Entity::find()
+            .filter(profile::Column::Name.eq(dto.name.clone()))
+            .one(conn.as_ref())
+            .await
+            .map_err(|e| anyhow!(e))?
+            .ok_or_else(|| anyhow!("failed to find the upserted profile"))?;
+
+        Ok((&result).into())
+    }
+
+    async fn delete(&self, id: &str) -> Result<(), Box<dyn Error>> {
+        let id: i32 = id.parse().map_err(|e: ParseIntError| anyhow!(e))?;
+
+        profile::Entity::delete_by_id(id)
+            .exec(self.connection.as_ref())
+            .await
+            .map_err(|e| anyhow!(e))?;
+
+        Ok(())
+    }
+}
diff --git a/core/src/profile/mod.rs b/core/src/profile/mod.rs
new file mode 100644
index 0000000..f607176
--- /dev/null
+++ b/core/src/profile/mod.rs
@@ -0,0 +1,31 @@
+use std::sync::Arc;
+
+use domain::{ProfileRepository, SelectedProfileService};
+use infrastructure::SeaOrmProfileRepository;
+
+pub mod domain;
+pub mod infrastructure;
+
+pub struct ProfileDIModule {
+    connection: Arc<::sea_orm::DatabaseConnection>,
+}
+
+impl ProfileDIModule {
+    pub fn new(connection: Arc<::sea_orm::DatabaseConnection>) -> Self {
+        Self { connection }
+    }
+
+    fn get_connection(&self) -> Arc<::sea_orm::DatabaseConnection> {
+        Arc::clone(&self.connection)
+    }
+
+    pub fn get_profile_repository(&self) -> Arc<dyn ProfileRepository> {
+        let connection = self.get_connection();
+        Arc::new(SeaOrmProfileRepository::new(Arc::clone(&connection)))
+    }
+
+    pub fn get_selected_profiles_service(&self) -> Arc<SelectedProfileService> {
+        let profile_repository = self.get_profile_repository();
+        Arc::new(SelectedProfileService::new(Arc::clone(&profile_repository)))
+    }
+}
From d1cf5a8c02e61f1a9d21e9d688271b1c928869ab Mon Sep 17 00:00:00 2001
From: Samuel
Date: Wed, 2 Oct 2024 16:44:38 -0400
Subject: [PATCH 2/8] feat: system prompt builder

---
 core/Cargo.toml                                    |  1 +
 core/README.md                                     | 49 +++++++++++++++++
 .../profile/domain/computer_info_service.rs        | 19 +++++++
 core/src/profile/domain/mod.rs                     |  3 ++
 .../profile/domain/system_prompt_builder.rs        | 48 ++++++++++++++++++
 .../domain/system_prompt_builder_test.rs           | 35 +++++++++++++
 6 files changed, 155 insertions(+)
 create mode 100644 core/README.md
 create mode 100644 core/src/profile/domain/computer_info_service.rs
 create mode 100644 core/src/profile/domain/system_prompt_builder.rs
 create mode 100644 core/src/profile/domain/system_prompt_builder_test.rs
diff --git a/core/Cargo.toml b/core/Cargo.toml
index 7367f01..dfd11d4 100644
--- a/core/Cargo.toml
+++ b/core/Cargo.toml
@@ -22,6 +22,7 @@ chrono = "0.4.38"
 async-stream = "0.3.5"
 anyhow = "1.0.89"
 itertools = "0.13.0"
+sysinfo = "0.31.4"

 [dev-dependencies]
 claim = "0.5.0"
diff --git a/core/README.md b/core/README.md
new file mode 100644
index 0000000..b4510b6
--- /dev/null
+++ b/core/README.md
@@ -0,0 +1,49 @@
+# AI Likes Human Core
+
+## Domains
+
+### Assistant
+
+Everything related to the Assistant and its storage:
+Thread, Message, Run, etc.
+
+Very close to the OpenAI Assistant API.
+
+### Chat Completion
+
+Everything related to calling LLMs in the OpenAI Chat Completion API format.
+
+### Configuration
+
+App configuration. Mainly the storage of various configurations.
+For example:
+
+- API keys
+- Selected LLM
+- LLM parameters
+
+### Entities
+
+Database entities generated by the `sea-orm` migration system.
+
+### infrastructure
+
+Infrastructure code used by every other domain.
+
+### llm
+
+LLM inference. Implementations of every LLM used in this app.
+For example:
+
+- OpenAI
+- Anthropic Claude
+- LLamaCPP
+- Etc...
+
+### profile
+
+User profiles. For example system prompt for a `Software Engineer`
+
+### utils
+
+Utility functions used by every other domain
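+
+## Example
+
+A minimal sketch of how a profile feeds a system prompt, using the
+`SystemPromptBuilder` added below (paths are crate-relative; the profile
+values are illustrative, not the seeded ones):
+
+```rust
+use crate::profile::domain::{dto::ProfileDto, SystemPromptBuilder};
+
+fn example_prompt() -> String {
+    let profile = ProfileDto {
+        name: "Illustrative Profile".to_string(),
+        prompt: "Be straightforward and concise.".to_string(),
+    };
+
+    // Each `with_*` call appends a chunk; `build()` joins the chunks.
+    SystemPromptBuilder::new()
+        .with_computer_info()
+        .with_profile(&profile)
+        .build()
+}
+```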
diff --git a/core/src/profile/domain/computer_info_service.rs b/core/src/profile/domain/computer_info_service.rs
new file mode 100644
index 0000000..221d858
--- /dev/null
+++ b/core/src/profile/domain/computer_info_service.rs
@@ -0,0 +1,19 @@
+use std::fmt;
+
+use sysinfo::System;
+
+pub struct ComputerInfo {
+    pub os: String,
+    pub os_version: String,
+}
+
+impl fmt::Display for ComputerInfo {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "Operating System: {} {}", self.os, self.os_version)
+    }
+}
+
+pub fn get_computer_info() -> ComputerInfo {
+    ComputerInfo {
+        os: System::name().unwrap_or_default(),
+        os_version: System::kernel_version().unwrap_or_default(),
+    }
+}
diff --git a/core/src/profile/domain/mod.rs b/core/src/profile/domain/mod.rs
index 0977bcc..fc04c2e 100644
--- a/core/src/profile/domain/mod.rs
+++ b/core/src/profile/domain/mod.rs
@@ -1,6 +1,9 @@
+mod computer_info_service;
 pub mod dto;
 mod profile_repository;
 mod selected_profile_service;
+mod system_prompt_builder;

 pub use profile_repository::ProfileRepository;
 pub use selected_profile_service::SelectedProfileService;
+pub use system_prompt_builder::SystemPromptBuilder;
diff --git a/core/src/profile/domain/system_prompt_builder.rs b/core/src/profile/domain/system_prompt_builder.rs
new file mode 100644
index 0000000..3677885
--- /dev/null
+++ b/core/src/profile/domain/system_prompt_builder.rs
@@ -0,0 +1,48 @@
+#[cfg(test)]
+#[path = "./system_prompt_builder_test.rs"]
+mod system_prompt_builder_test;
+
+use super::{computer_info_service::get_computer_info, dto::ProfileDto};
+
+pub struct SystemPromptBuilder {
+    prompt_chunks: Vec<String>,
+}
+
+impl SystemPromptBuilder {
+    pub fn new() -> Self {
+        Self {
+            prompt_chunks: vec![],
+        }
+    }
+
+    pub fn with_computer_info(self) -> Self {
+        self.with_section("Computer Info", &get_computer_info().to_string())
+    }
+
+    pub fn with_profile(self, profile: &ProfileDto) -> Self {
+        self.with(&profile.prompt)
+    }
+
+    pub fn with_section(mut self, title: &str, content: &str) -> Self {
+        self.prompt_chunks
+            .push(format!("# {}\n{}\n", title, content));
+
+        self
+    }
+
+    pub fn with(mut self, content: &str) -> Self {
+        self.prompt_chunks.push(format!("{}\n", content));
+
+        self
+    }
+
+    pub fn build(mut self) -> String {
+        self = self.with_personal_assistant_role();
+
+        self.prompt_chunks.join("\n")
+    }
+
+    fn with_personal_assistant_role(self) -> Self {
+        self.with_section("role", "You are a personal assistant.")
+    }
+}
diff --git a/core/src/profile/domain/system_prompt_builder_test.rs b/core/src/profile/domain/system_prompt_builder_test.rs
new file mode 100644
index 0000000..ae414f1
--- /dev/null
+++ b/core/src/profile/domain/system_prompt_builder_test.rs
@@ -0,0 +1,35 @@
+#[cfg(test)]
+mod tests {
+    use crate::profile::domain::{dto::ProfileDto, SystemPromptBuilder};
+
+    #[test]
+    fn test_prompt_have_personal_assistant_role() {
+        let builder = SystemPromptBuilder::new();
+
+        let prompt = builder.build();
+
+        assert!(prompt.contains("# role\nYou are a personal assistant."));
+    }
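+
+    // Illustrative sketch: `with_section` is the primitive behind
+    // `with_computer_info`; it renders a section as `# {title}` followed
+    // by the content on the next line.
+    #[test]
+    fn test_prompt_with_section_sketch() {
+        let builder = SystemPromptBuilder::new().with_section("style", "Be concise.");
+
+        let prompt = builder.build();
+
+        assert!(prompt.contains("# style\nBe concise.\n"));
+    }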
+
+    #[test]
+    fn test_prompt_have_computer_info() {
+        let builder = SystemPromptBuilder::new().with_computer_info();
+
+        let prompt = builder.build();
+
+        assert!(prompt.contains("# Computer Info"));
+    }
+
+    #[test]
+    fn test_prompt_with_profile() {
+        let profile = &ProfileDto {
+            name: "some".to_string(),
+            prompt: "You are a super hero.".to_string(),
+        };
+        let builder = SystemPromptBuilder::new().with_profile(profile);
+
+        let prompt = builder.build();
+
+        assert!(prompt.contains("You are a super hero."));
+    }
+}
From a3f16c6be0e6faccc187f5d4e5e1a264531e8cc5 Mon Sep 17 00:00:00 2001
From: Samuel
Date: Wed, 2 Oct 2024 20:45:43 -0400
Subject: [PATCH 3/8] feat: agent factory

---
 core/README.md                                     |  3 +
 core/src/api_facade.rs                             |  2 +-
 core/src/app_container.rs                          |  7 +-
 core/src/llm/domain/agent/agent_factory.rs         | 43 +++++++++++
 .../llm/domain/agent/agent_factory_test.rs         | 65 +++++++++++++++++
 .../llm/domain/agent/base_agent_factory.rs         | 19 +++++
 .../llm/domain/agent/default_agent_factory.rs      | 73 +++++++++++++++++++
 .../agent/default_system_prompt_factory.rs         | 30 ++++++++
 core/src/llm/domain/agent/mod.rs                   |  6 ++
 core/src/llm/domain/mod.rs                         |  1 +
 core/src/llm/mod.rs                                | 42 ++++++++++-
 .../domain/selected_profile_service.rs             |  2 +-
 .../profile/domain/system_prompt_builder.rs        | 18 +++--
 .../domain/system_prompt_builder_test.rs           | 20 ++++-
 14 files changed, 317 insertions(+), 14 deletions(-)
 create mode 100644 core/src/llm/domain/agent/agent_factory.rs
 create mode 100644 core/src/llm/domain/agent/agent_factory_test.rs
 create mode 100644 core/src/llm/domain/agent/base_agent_factory.rs
 create mode 100644 core/src/llm/domain/agent/default_agent_factory.rs
 create mode 100644 core/src/llm/domain/agent/default_system_prompt_factory.rs
 create mode 100644 core/src/llm/domain/agent/mod.rs
diff --git a/core/README.md b/core/README.md
index b4510b6..8c6b934 100644
--- a/core/README.md
+++ b/core/README.md
@@ -40,6 +40,9 @@ For example:
 - LLamaCPP
 - Etc...

+Also implements agent creation: building a custom agent on top of an existing LLM (using langchain-rust).
+An agent can add tools, chain of thought, etc.
+
 ### profile

 User profiles. 
For example system prompt for a `Software Engineer` diff --git a/core/src/api_facade.rs b/core/src/api_facade.rs index a039aea..d279973 100644 --- a/core/src/api_facade.rs +++ b/core/src/api_facade.rs @@ -72,7 +72,7 @@ impl ApiFacade { self.container .profile_module .get_selected_profiles_service() - .get() + .find_selected_profiles() .await } } diff --git a/core/src/app_container.rs b/core/src/app_container.rs index 0a420fc..6da41a3 100644 --- a/core/src/app_container.rs +++ b/core/src/app_container.rs @@ -22,9 +22,12 @@ impl AppContainer { let connection: Arc<::sea_orm::DatabaseConnection> = connection_factory.create().await?; let configuration_module = Arc::new(ConfigurationDIModule::new(Arc::clone(&connection))); - let llm_module = Arc::new(LLMDIModule::new(configuration_module.clone())); - let chat_completion_module = Arc::new(ChatCompletionDIModule::new(Arc::clone(&llm_module))); let profile_module = Arc::new(ProfileDIModule::new(Arc::clone(&connection))); + let llm_module = Arc::new(LLMDIModule::new( + configuration_module.clone(), + profile_module.clone(), + )); + let chat_completion_module = Arc::new(ChatCompletionDIModule::new(Arc::clone(&llm_module))); let agent_module: AgentDIModule = AgentDIModule::new(Arc::clone(&connection), Arc::clone(&chat_completion_module)); diff --git a/core/src/llm/domain/agent/agent_factory.rs b/core/src/llm/domain/agent/agent_factory.rs new file mode 100644 index 0000000..a4f5ad4 --- /dev/null +++ b/core/src/llm/domain/agent/agent_factory.rs @@ -0,0 +1,43 @@ +#[cfg(test)] +#[path = "./agent_factory_test.rs"] +mod agent_factory_test; + +use super::base_agent_factory::{BaseAgentFactory, CreateAgentArgs}; +use anyhow::anyhow; +use langchain_rust::chain::Chain; +use std::{error::Error, sync::Arc}; + +pub struct AgentFactory { + agent_factories: Vec>, +} + +impl AgentFactory { + pub fn new(agent_factories: Vec>) -> Self { + Self { agent_factories } + } + + fn find_factory(&self, agent_id: &str) -> Option> { + for factory in &self.agent_factories { + if factory.is_compatible(agent_id) { + return Some(factory.clone()); + } + } + + None + } +} + +impl AgentFactory { + pub async fn create( + &self, + agent_id: &str, + args: &CreateAgentArgs, + ) -> Result, Box> { + let chain = match self.find_factory(agent_id) { + Some(factory) => factory.create(args).await, + None => Err(anyhow!("Agent not found: {}", agent_id).into()), + }; + + chain + } +} diff --git a/core/src/llm/domain/agent/agent_factory_test.rs b/core/src/llm/domain/agent/agent_factory_test.rs new file mode 100644 index 0000000..530c3cd --- /dev/null +++ b/core/src/llm/domain/agent/agent_factory_test.rs @@ -0,0 +1,65 @@ +#[cfg(test)] +mod test { + use langchain_rust::{ + chain::{Chain, ChainError}, + language_models::GenerateResult, + prompt::PromptArgs, + }; + use mockall::{mock, predicate::eq}; + use std::{collections::HashMap, sync::Arc}; + + use crate::llm::domain::agent::{ + base_agent_factory::{CreateAgentArgs, MockBaseAgentFactory}, + AgentFactory, + }; + + mock! 
{ + ChainStub { + } + + impl Clone for ChainStub { + fn clone(&self) -> Self { + ChainStub {} + } + } + + #[async_trait::async_trait] + impl Chain for ChainStub { + async fn call(&self, input_variables: PromptArgs) -> Result { + todo!() + } + + } + } + + #[tokio::test] + async fn test_agent_corresponding_to_id_is_returned() { + let mut factory1 = MockBaseAgentFactory::new(); + let mut factory2 = MockBaseAgentFactory::new(); + factory1 + .expect_is_compatible() + .with(eq("AAA")) + .return_const(true); + factory1.expect_create().returning(|_| { + let mut chain_mock = MockChainStub::new(); + chain_mock.expect_call().returning(|_| { + Ok(GenerateResult { + generation: "some gen result".to_string(), + ..GenerateResult::default() + }) + }); + + Ok(Box::new(chain_mock)) + }); + factory2.expect_is_compatible().return_const(false); + let instance = AgentFactory::new(vec![Arc::new(factory1), Arc::new(factory2)]); + + let result = instance + .create("AAA", &CreateAgentArgs::default()) + .await + .unwrap(); + let gen_result = result.call(HashMap::new()).await.unwrap(); + + assert_eq!(gen_result.generation, "some gen result"); + } +} diff --git a/core/src/llm/domain/agent/base_agent_factory.rs b/core/src/llm/domain/agent/base_agent_factory.rs new file mode 100644 index 0000000..1d25a0f --- /dev/null +++ b/core/src/llm/domain/agent/base_agent_factory.rs @@ -0,0 +1,19 @@ +use std::error::Error; + +use langchain_rust::chain::Chain; +use mockall::automock; + +#[derive(Default)] +pub struct CreateAgentArgs { + pub model: String, + pub temperature: Option, +} + +#[automock] +#[async_trait::async_trait] +pub trait BaseAgentFactory { + fn is_compatible(&self, agent_id: &str) -> bool; + + async fn create(&self, args: &CreateAgentArgs) + -> Result, Box>; +} diff --git a/core/src/llm/domain/agent/default_agent_factory.rs b/core/src/llm/domain/agent/default_agent_factory.rs new file mode 100644 index 0000000..a0d727b --- /dev/null +++ b/core/src/llm/domain/agent/default_agent_factory.rs @@ -0,0 +1,73 @@ +use std::{error::Error, sync::Arc}; + +use langchain_rust::{ + agent::{AgentExecutor, ConversationalAgentBuilder}, + chain::{options::ChainCallOptions, Chain}, +}; + +use crate::llm::domain::llm_factory::{CreateLLMParameters, LLMFactory}; + +use super::{ + base_agent_factory::{BaseAgentFactory, CreateAgentArgs}, + default_system_prompt_factory::DefaultSystemPromptFactory, +}; + +pub struct DefaultAgentFactory { + llm_factory: Arc, + system_prompt_factory: Arc, +} + +impl DefaultAgentFactory { + pub fn new( + llm_factory: Arc, + system_prompt_factory: Arc, + ) -> Self { + Self { + llm_factory, + system_prompt_factory, + } + } + + fn create_options(args: &CreateAgentArgs) -> ChainCallOptions { + let options: ChainCallOptions = ChainCallOptions::new(); + let options = match args.temperature { + Some(temperature) => options.with_temperature(temperature), + None => options, + }; + + options + } +} + +#[async_trait::async_trait] +impl BaseAgentFactory for DefaultAgentFactory { + fn is_compatible(&self, _agent_id: &str) -> bool { + true + } + + async fn create( + &self, + args: &CreateAgentArgs, + ) -> Result, Box> { + let system_prompt = self.system_prompt_factory.create().await?; + let llm = self + .llm_factory + .create(&CreateLLMParameters { + model: args.model.to_string(), + temperature: args.temperature, + ..Default::default() + }) + .await?; + + let agent = ConversationalAgentBuilder::new() + // .tools(&[Arc::new(command_executor)]) + .options(Self::create_options(args)) + .prefix(system_prompt) + .build(llm) + 
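+            // `build` returns a Result; `unwrap` panics if the agent cannot be assembled.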
.unwrap(); + + let executor = AgentExecutor::from_agent(agent); + + Ok(Box::new(executor)) + } +} diff --git a/core/src/llm/domain/agent/default_system_prompt_factory.rs b/core/src/llm/domain/agent/default_system_prompt_factory.rs new file mode 100644 index 0000000..c3d8bd6 --- /dev/null +++ b/core/src/llm/domain/agent/default_system_prompt_factory.rs @@ -0,0 +1,30 @@ +use std::{error::Error, sync::Arc}; + +use crate::profile::domain::{SelectedProfileService, SystemPromptBuilder}; + +pub struct DefaultSystemPromptFactory { + selected_profile_service: Arc, +} + +impl DefaultSystemPromptFactory { + pub fn new(selected_profile_service: Arc) -> Self { + Self { + selected_profile_service, + } + } + + pub async fn create(&self) -> Result> { + let profiles = self + .selected_profile_service + .find_selected_profiles() + .await?; + + let system_prompt = SystemPromptBuilder::new() + .with_personal_assistant_role() + .with_computer_info() + .with_profiles(&profiles) + .build(); + + Ok(system_prompt) + } +} diff --git a/core/src/llm/domain/agent/mod.rs b/core/src/llm/domain/agent/mod.rs new file mode 100644 index 0000000..25eb3b1 --- /dev/null +++ b/core/src/llm/domain/agent/mod.rs @@ -0,0 +1,6 @@ +pub mod agent_factory; +pub mod base_agent_factory; +pub mod default_agent_factory; +pub mod default_system_prompt_factory; + +pub use agent_factory::AgentFactory; diff --git a/core/src/llm/domain/mod.rs b/core/src/llm/domain/mod.rs index 063e172..9b6c619 100644 --- a/core/src/llm/domain/mod.rs +++ b/core/src/llm/domain/mod.rs @@ -1,3 +1,4 @@ +pub mod agent; pub mod api_key_service; pub mod llm_factory; pub mod message_type_adapter; diff --git a/core/src/llm/mod.rs b/core/src/llm/mod.rs index 5188649..77b230b 100644 --- a/core/src/llm/mod.rs +++ b/core/src/llm/mod.rs @@ -1,6 +1,10 @@ use std::sync::Arc; use domain::{ + agent::{ + base_agent_factory::BaseAgentFactory, default_agent_factory::DefaultAgentFactory, + default_system_prompt_factory::DefaultSystemPromptFactory, AgentFactory, + }, api_key_service::{ApiKeyService, ApiKeyServiceImpl}, llm_factory::LLMFactory, }; @@ -9,19 +13,24 @@ use infrastructure::{ openai_llm_factory::OpenAILLMFactory, }; -use crate::configuration::ConfigurationDIModule; +use crate::{configuration::ConfigurationDIModule, profile::ProfileDIModule}; pub mod domain; pub mod infrastructure; pub struct LLMDIModule { configuration_module: Arc, + profile_module: Arc, } impl LLMDIModule { - pub fn new(configuration_module: Arc) -> Self { + pub fn new( + configuration_module: Arc, + profile_module: Arc, + ) -> Self { Self { configuration_module, + profile_module, } } @@ -45,4 +54,33 @@ impl LLMDIModule { anthropic_llm_factory, ])) } + + fn get_default_system_prompt_factory(&self) -> Arc { + let selected_profile_service = self.profile_module.get_selected_profiles_service(); + + Arc::new(DefaultSystemPromptFactory::new( + selected_profile_service.clone(), + )) + } + + fn get_default_agent_factory(&self) -> Arc { + let llm_factory = self.get_llm_factory(); + let system_prompt_factory = self.get_default_system_prompt_factory(); + + let default = DefaultAgentFactory::new(llm_factory, system_prompt_factory); + + Arc::new(default) + } + + fn get_base_agent_factories(&self) -> Vec> { + let default = self.get_default_agent_factory(); + + vec![default] + } + + pub fn get_agent_factory(&self) -> Arc { + let base_factories = self.get_base_agent_factories(); + + Arc::new(AgentFactory::new(base_factories)) + } } diff --git a/core/src/profile/domain/selected_profile_service.rs 
b/core/src/profile/domain/selected_profile_service.rs
index 2606de3..b53137d 100644
--- a/core/src/profile/domain/selected_profile_service.rs
+++ b/core/src/profile/domain/selected_profile_service.rs
@@ -13,7 +13,7 @@ impl SelectedProfileService {
         Self { profile_repository }
     }

-    pub async fn get(&self) -> Result<Vec<ProfileDto>, Box<dyn Error>> {
+    pub async fn find_selected_profiles(&self) -> Result<Vec<ProfileDto>, Box<dyn Error>> {
         let profile = self
             .profile_repository
             .find_by_name(SOFTWARE_ENGINEER_PROFILE_NAME)
diff --git a/core/src/profile/domain/system_prompt_builder.rs b/core/src/profile/domain/system_prompt_builder.rs
index 3677885..a3345c1 100644
--- a/core/src/profile/domain/system_prompt_builder.rs
+++ b/core/src/profile/domain/system_prompt_builder.rs
@@ -23,6 +23,12 @@ impl SystemPromptBuilder {
         self.with(&profile.prompt)
     }

+    pub fn with_profiles(self, profiles: &[ProfileDto]) -> Self {
+        profiles
+            .iter()
+            .fold(self, |acc, profile| acc.with_profile(profile))
+    }
+
     pub fn with_section(mut self, title: &str, content: &str) -> Self {
         self.prompt_chunks
             .push(format!("# {}\n{}\n", title, content));
@@ -30,19 +36,17 @@ impl SystemPromptBuilder {
         self
     }

+    pub fn with_personal_assistant_role(self) -> Self {
+        self.with_section("role", "You are a personal assistant.")
+    }
+
     pub fn with(mut self, content: &str) -> Self {
         self.prompt_chunks.push(format!("{}\n", content));

         self
     }

-    pub fn build(mut self) -> String {
-        self = self.with_personal_assistant_role();
-
+    pub fn build(self) -> String {
         self.prompt_chunks.join("\n")
     }
-
-    fn with_personal_assistant_role(self) -> Self {
-        self.with_section("role", "You are a personal assistant.")
-    }
 }
diff --git a/core/src/profile/domain/system_prompt_builder_test.rs b/core/src/profile/domain/system_prompt_builder_test.rs
index ae414f1..632304c 100644
--- a/core/src/profile/domain/system_prompt_builder_test.rs
+++ b/core/src/profile/domain/system_prompt_builder_test.rs
@@ -4,7 +4,7 @@ mod tests {

     #[test]
     fn test_prompt_have_personal_assistant_role() {
-        let builder = SystemPromptBuilder::new();
+        let builder = SystemPromptBuilder::new().with_personal_assistant_role();

         let prompt = builder.build();

@@ -32,4 +32,22 @@ mod tests {

         assert!(prompt.contains("You are a super hero."));
     }
+
+    #[test]
+    fn test_prompt_with_multiple_profiles() {
+        let profile1 = ProfileDto {
+            name: "some".to_string(),
+            prompt: "You are a super hero.".to_string(),
+        };
+        let profile2 = ProfileDto {
+            name: "other".to_string(),
+            prompt: "Be gentle".to_string(),
+        };
+        let builder = SystemPromptBuilder::new().with_profiles(&[profile1, profile2]);
+
+        let prompt = builder.build();
+
+        assert!(prompt.contains("You are a super hero."));
+        assert!(prompt.contains("Be gentle"));
+    }
 }
From 2d777497e6c77a1a2bf7f759ead6a719c67b8cd0 Mon Sep 17 00:00:00 2001
From: Samuel
Date: Wed, 2 Oct 2024 21:57:15 -0400
Subject: [PATCH 4/8] feat: inference using agent instead of llm

---
 .../domain/inference_service.rs                    | 103 ++++++++++++------
 core/src/chat_completion/domain/mod.rs             |   4 +-
 core/src/llm/domain/agent/agent_factory.rs         |   8 ++
 .../llm/domain/agent/base_agent_factory.rs         |   2 +-
 4 files changed, 80 insertions(+), 37 deletions(-)
diff --git a/core/src/chat_completion/domain/inference_service.rs b/core/src/chat_completion/domain/inference_service.rs
index 524de58..ff06e87 100644
--- a/core/src/chat_completion/domain/inference_service.rs
+++ b/core/src/chat_completion/domain/inference_service.rs
@@ -1,22 +1,21 @@
-use std::{error::Error, sync::Arc};
-
-use async_stream::stream;
-
-use crate::llm::domain::llm_factory::{CreateLLMParameters,
LLMFactory}; +use std::{collections::HashMap, error::Error, sync::Arc}; use super::{ dto::{ChatCompletionChunkObject, ChatCompletionMessageDto, ChatCompletionObject}, ChatCompletionResult, ChatCompletionStream, }; +use crate::llm::domain::agent::{base_agent_factory::CreateAgentArgs, AgentFactory}; +use async_stream::stream; use futures::StreamExt; +use langchain_rust::{chain::Chain, prompt_args}; pub struct InferenceService { - llm_factory: Arc, + agent_factory: Arc, } impl InferenceService { - pub fn new(llm_factory: Arc) -> Self { - Self { llm_factory } + pub fn new(agent_factory: Arc) -> Self { + Self { agent_factory } } pub async fn invoke( @@ -24,20 +23,14 @@ impl InferenceService { model: &str, messages: &Vec, ) -> ChatCompletionResult { - let messages: Vec = - messages.iter().map(|m| m.clone().into()).collect(); - let llm = self - .llm_factory - .create(&CreateLLMParameters { - model: model.to_string(), - ..CreateLLMParameters::default() - }) + let input_variables = Self::messages_to_input_variables(messages); + let chain = self + .create_chain("default", model) .await .map_err(|e| e as Box)?; + let result = chain.invoke(input_variables).await?; - let result = llm.generate(&messages[..]).await?; - - let message = ChatCompletionMessageDto::assistant(&result.generation); + let message = ChatCompletionMessageDto::assistant(&result); let data = ChatCompletionObject::new_single_choice(message, model); Ok(data) @@ -48,24 +41,22 @@ impl InferenceService { model: &str, messages: &Vec, ) -> ChatCompletionStream { - let messages: Vec = - messages.iter().map(|m| m.clone().into()).collect(); + let input_variables = Self::messages_to_input_variables(messages); + + let model: String = model.to_string(); + let self_clone = self.clone(); - let model = model.to_string(); - let llm_factory = Arc::clone(&self.llm_factory); let stream = stream! { - let llm = match llm_factory.create(&CreateLLMParameters { - model: model.clone(), - ..CreateLLMParameters::default() - }).await { - Ok(llm) => llm, - Err(e) => { - yield Err(e); - return; - } - }; + let chain = match self_clone.create_chain("default", &model) + .await { + Ok(chain) => chain, + Err(e) => { + yield Err(e); + return; + } + }; - let mut llm_stream = match llm.stream(&messages[..]).await { + let mut llm_stream = match chain.stream(input_variables).await { Ok(stream) => stream, Err(e) => { yield Err(Box::new(e)); @@ -89,4 +80,48 @@ impl InferenceService { Box::pin(stream) } + + async fn create_chain( + &self, + agent_id: &str, + model: &str, + ) -> Result, Box> { + let chain = self + .agent_factory + .create( + agent_id, + &CreateAgentArgs { + model: model.to_string(), + ..CreateAgentArgs::default() + }, + ) + .await?; + + Ok(chain) + } + + fn messages_to_string(messages: &Vec) -> String { + let messages: Vec = + messages.iter().map(|m| m.clone().into()).collect(); + + langchain_rust::schemas::Message::messages_to_string(&messages) + } + + fn messages_to_input_variables( + messages: &Vec, + ) -> HashMap { + let input = Self::messages_to_string(messages); + + prompt_args! 
{ + "input" => input, + } + } +} + +impl Clone for InferenceService { + fn clone(&self) -> Self { + InferenceService { + agent_factory: Arc::clone(&self.agent_factory), + } + } } diff --git a/core/src/chat_completion/domain/mod.rs b/core/src/chat_completion/domain/mod.rs index 874f4d4..12017d4 100644 --- a/core/src/chat_completion/domain/mod.rs +++ b/core/src/chat_completion/domain/mod.rs @@ -21,8 +21,8 @@ impl ChatCompletionDIModule { } pub fn get_inference_factory(&self) -> Arc { - let llm_factory = self.llm_module.get_llm_factory(); + let agent_factory = self.llm_module.get_agent_factory(); - Arc::new(InferenceService::new(llm_factory)) + Arc::new(InferenceService::new(agent_factory)) } } diff --git a/core/src/llm/domain/agent/agent_factory.rs b/core/src/llm/domain/agent/agent_factory.rs index a4f5ad4..dadc270 100644 --- a/core/src/llm/domain/agent/agent_factory.rs +++ b/core/src/llm/domain/agent/agent_factory.rs @@ -41,3 +41,11 @@ impl AgentFactory { chain } } + +impl Clone for AgentFactory { + fn clone(&self) -> Self { + Self { + agent_factories: self.agent_factories.clone(), + } + } +} diff --git a/core/src/llm/domain/agent/base_agent_factory.rs b/core/src/llm/domain/agent/base_agent_factory.rs index 1d25a0f..6915b4d 100644 --- a/core/src/llm/domain/agent/base_agent_factory.rs +++ b/core/src/llm/domain/agent/base_agent_factory.rs @@ -11,7 +11,7 @@ pub struct CreateAgentArgs { #[automock] #[async_trait::async_trait] -pub trait BaseAgentFactory { +pub trait BaseAgentFactory: Send + Sync { fn is_compatible(&self, agent_id: &str) -> bool; async fn create(&self, args: &CreateAgentArgs) From e5e5711112bb4bc71802fd5e615579fc74210e49 Mon Sep 17 00:00:00 2001 From: Samuel Date: Thu, 3 Oct 2024 10:07:04 -0400 Subject: [PATCH 5/8] refact: split to allow agent inference --- core/src/api_facade.rs | 33 +++-- .../domain/stream_thread_run_service.rs | 2 +- .../thread_chat_completions_inference.rs | 24 +++- core/src/assistant/mod.rs | 6 +- .../chat_completion/domain/agent_inference.rs | 120 +++++++++++++++++ core/src/chat_completion/domain/inference.rs | 19 +++ .../domain/inference_service.rs | 127 ------------------ .../domain/langchain_adapter.rs | 13 ++ .../chat_completion/domain/llm_inference.rs | 91 +++++++++++++ core/src/chat_completion/domain/mod.rs | 25 +--- core/src/chat_completion/domain/types.rs | 2 +- core/src/chat_completion/mod.rs | 21 ++- core/src/llm/domain/agent/agent_factory.rs | 2 +- .../llm/domain/agent/agent_factory_test.rs | 2 +- .../llm/domain/agent/base_agent_factory.rs | 3 +- .../llm/domain/agent/default_agent_factory.rs | 9 +- .../src/controller/chat_completions.rs | 4 +- 17 files changed, 318 insertions(+), 185 deletions(-) create mode 100644 core/src/chat_completion/domain/agent_inference.rs create mode 100644 core/src/chat_completion/domain/inference.rs delete mode 100644 core/src/chat_completion/domain/inference_service.rs create mode 100644 core/src/chat_completion/domain/llm_inference.rs diff --git a/core/src/api_facade.rs b/core/src/api_facade.rs index d279973..260ed36 100644 --- a/core/src/api_facade.rs +++ b/core/src/api_facade.rs @@ -1,7 +1,10 @@ use std::sync::Arc; use crate::{ - chat_completion::{ChatCompletionMessageDto, ChatCompletionResult, ChatCompletionStream}, + chat_completion::{ + inference::{Inference, InferenceArgs}, + ChatCompletionMessageDto, ChatCompletionResult, ChatCompletionStream, + }, configuration::ConfigurationDto, profile::domain::dto::ProfileDto, AppContainer, @@ -21,25 +24,31 @@ impl ApiFacade { model: &str, messages: &Vec, ) -> 
ChatCompletionResult { - let factory = self - .container - .chat_completion_module - .get_inference_factory(); + let inference = self.container.chat_completion_module.get_llm_inference(); - factory.invoke(model, messages).await + inference + .invoke(InferenceArgs { + model: model.to_string(), + messages: messages.clone(), + ..Default::default() + }) + .await } - pub fn chat_completion_stream( + pub async fn chat_completion_stream( &self, model: &str, messages: &Vec, ) -> ChatCompletionStream { - let factory = self - .container - .chat_completion_module - .get_inference_factory(); + let inference = self.container.chat_completion_module.get_llm_inference(); - factory.stream(model, messages) + inference + .stream(InferenceArgs { + model: model.to_string(), + messages: messages.clone(), + ..Default::default() + }) + .await } pub async fn find_configuration( diff --git a/core/src/assistant/domain/stream_thread_run_service.rs b/core/src/assistant/domain/stream_thread_run_service.rs index 1eff0c8..c177ce9 100644 --- a/core/src/assistant/domain/stream_thread_run_service.rs +++ b/core/src/assistant/domain/stream_thread_run_service.rs @@ -157,7 +157,7 @@ impl StreamThreadRunService { yield ThreadEventDto::thread_message_created(&response_message); yield ThreadEventDto::thread_message_in_progress(&response_message); - let mut stream = inference_service.stream(&run.model, &messages); + let mut stream = inference_service.stream(&run.model, &messages).await; while let Some(chunk) = stream.next().await { let chunk = match chunk { Ok(chunk) => chunk, diff --git a/core/src/assistant/domain/thread_chat_completions_inference.rs b/core/src/assistant/domain/thread_chat_completions_inference.rs index c1a0241..712f9c8 100644 --- a/core/src/assistant/domain/thread_chat_completions_inference.rs +++ b/core/src/assistant/domain/thread_chat_completions_inference.rs @@ -1,22 +1,34 @@ use std::sync::Arc; -use crate::chat_completion::{self, ChatCompletionMessageDto, ChatCompletionStream}; +use crate::chat_completion::{ + self, inference::InferenceArgs, ChatCompletionMessageDto, ChatCompletionStream, +}; use super::dto::ThreadMessageDto; pub struct ThreadChatCompletionInference { - inference_service: Arc, + inference: Arc, } impl ThreadChatCompletionInference { - pub fn new(inference_service: Arc) -> Self { - Self { inference_service } + pub fn new(inference: Arc) -> Self { + Self { inference } } - pub fn stream(&self, model: &str, messages: &Vec) -> ChatCompletionStream { + pub async fn stream( + &self, + model: &str, + messages: &Vec, + ) -> ChatCompletionStream { let messages: Vec = messages.iter().map(|m| m.clone().into()).collect(); - self.inference_service.stream(model, &messages) + self.inference + .stream(InferenceArgs { + model: model.to_string(), + messages: messages, + ..Default::default() + }) + .await } } diff --git a/core/src/assistant/mod.rs b/core/src/assistant/mod.rs index ab1d3dd..3e4e063 100644 --- a/core/src/assistant/mod.rs +++ b/core/src/assistant/mod.rs @@ -74,11 +74,9 @@ impl AgentDIModule { } pub fn get_thread_inference_service(&self) -> Arc { - let chat_completion_inference_service = self.chat_completion_module.get_inference_factory(); + let inference = self.chat_completion_module.get_llm_inference(); - Arc::new(ThreadChatCompletionInference::new( - chat_completion_inference_service, - )) + Arc::new(ThreadChatCompletionInference::new(inference)) } pub fn get_run_status_mutator(&self) -> Arc { diff --git a/core/src/chat_completion/domain/agent_inference.rs 
b/core/src/chat_completion/domain/agent_inference.rs
new file mode 100644
index 0000000..8a8819b
--- /dev/null
+++ b/core/src/chat_completion/domain/agent_inference.rs
@@ -0,0 +1,120 @@
+use anyhow::anyhow;
+use async_stream::stream;
+use futures::StreamExt;
+use langchain_rust::{chain::Chain, prompt_args};
+use std::{error::Error, sync::Arc};
+
+use crate::{
+    chat_completion::domain::langchain_adapter::{
+        langchain_messages_to_string, messages_to_langchain_messages,
+    },
+    llm::domain::agent::{base_agent_factory::CreateAgentArgs, AgentFactory},
+};
+
+use super::{
+    inference::{Inference, InferenceArgs},
+    ChatCompletionChunkObject, ChatCompletionMessageDto, ChatCompletionObject,
+    ChatCompletionResult, ChatCompletionStream,
+};
+
+pub struct AgentInference {
+    agent_factory: Arc<AgentFactory>,
+}
+
+impl AgentInference {
+    pub fn new(agent_factory: Arc<AgentFactory>) -> Self {
+        Self { agent_factory }
+    }
+
+    fn get_input_variables(
+        messages: &[ChatCompletionMessageDto],
+    ) -> std::collections::HashMap<String, serde_json::Value> {
+        let messages = messages_to_langchain_messages(messages);
+        let input = langchain_messages_to_string(&messages);
+
+        prompt_args! {
+            "input" => input,
+        }
+    }
+
+    async fn get_agent(
+        &self,
+        args: &InferenceArgs,
+    ) -> Result<Box<dyn Chain>, Box<dyn Error>> {
+        self.agent_factory
+            .create(
+                "default",
+                CreateAgentArgs {
+                    model: args.model.to_string(),
+                    temperature: args.temperature,
+                    ..Default::default()
+                },
+            )
+            .await
+    }
+}
+
+#[async_trait::async_trait]
+impl Inference for AgentInference {
+    async fn invoke(&self, args: InferenceArgs) -> ChatCompletionResult {
+        let agent = self.get_agent(&args).await?;
+        let input_variables = Self::get_input_variables(&args.messages);
+
+        let result = agent
+            .invoke(input_variables)
+            .await
+            .map_err(|e| anyhow!(e))?;
+
+        let message = ChatCompletionMessageDto::assistant(&result);
+        let data = ChatCompletionObject::new_single_choice(message, &args.model);
+
+        Ok(data)
+    }
+
+    async fn stream(&self, args: InferenceArgs) -> ChatCompletionStream {
+        let input_variables = Self::get_input_variables(&args.messages);
+        let model = args.model.to_string();
+        let self_clone = self.clone();
+
+        let stream = stream!
{ + let agent = match self_clone.get_agent(&args).await { + Ok(agent) => agent, + Err(e) => { + yield Err(e); + return; + } + }; + + let mut agent_stream = match agent.stream(input_variables).await { + Ok(stream) => stream, + Err(e) => { + yield Err(Box::new(e)); + return; + } + }; + + while let Some(chunk) = agent_stream.next().await { + let chunk = match chunk { + Ok(chunk) => chunk, + Err(e) => { + yield Err(Box::new(e)); + return; + } + }; + let chunk = ChatCompletionChunkObject::new_assistant_chunk(&chunk.content, &model); + + yield Ok(chunk); + } + }; + + Box::pin(stream) + } +} + +impl Clone for AgentInference { + fn clone(&self) -> Self { + Self { + agent_factory: Arc::clone(&self.agent_factory), + } + } +} diff --git a/core/src/chat_completion/domain/inference.rs b/core/src/chat_completion/domain/inference.rs new file mode 100644 index 0000000..a4d8383 --- /dev/null +++ b/core/src/chat_completion/domain/inference.rs @@ -0,0 +1,19 @@ +use mockall::automock; + +use crate::chat_completion::{ + ChatCompletionMessageDto, ChatCompletionResult, ChatCompletionStream, +}; + +#[derive(Default)] +pub struct InferenceArgs { + pub model: String, + pub temperature: Option, + pub messages: Vec, +} + +#[automock] +#[async_trait::async_trait] +pub trait Inference: Send + Sync { + async fn invoke(&self, args: InferenceArgs) -> ChatCompletionResult; + async fn stream(&self, args: InferenceArgs) -> ChatCompletionStream; +} diff --git a/core/src/chat_completion/domain/inference_service.rs b/core/src/chat_completion/domain/inference_service.rs deleted file mode 100644 index ff06e87..0000000 --- a/core/src/chat_completion/domain/inference_service.rs +++ /dev/null @@ -1,127 +0,0 @@ -use std::{collections::HashMap, error::Error, sync::Arc}; - -use super::{ - dto::{ChatCompletionChunkObject, ChatCompletionMessageDto, ChatCompletionObject}, - ChatCompletionResult, ChatCompletionStream, -}; -use crate::llm::domain::agent::{base_agent_factory::CreateAgentArgs, AgentFactory}; -use async_stream::stream; -use futures::StreamExt; -use langchain_rust::{chain::Chain, prompt_args}; - -pub struct InferenceService { - agent_factory: Arc, -} - -impl InferenceService { - pub fn new(agent_factory: Arc) -> Self { - Self { agent_factory } - } - - pub async fn invoke( - &self, - model: &str, - messages: &Vec, - ) -> ChatCompletionResult { - let input_variables = Self::messages_to_input_variables(messages); - let chain = self - .create_chain("default", model) - .await - .map_err(|e| e as Box)?; - let result = chain.invoke(input_variables).await?; - - let message = ChatCompletionMessageDto::assistant(&result); - let data = ChatCompletionObject::new_single_choice(message, model); - - Ok(data) - } - - pub fn stream( - &self, - model: &str, - messages: &Vec, - ) -> ChatCompletionStream { - let input_variables = Self::messages_to_input_variables(messages); - - let model: String = model.to_string(); - let self_clone = self.clone(); - - let stream = stream! 
{
-            let chain = match self_clone.create_chain("default", &model)
-                .await {
-                Ok(chain) => chain,
-                Err(e) => {
-                    yield Err(e);
-                    return;
-                }
-            };
-
-            let mut llm_stream = match chain.stream(input_variables).await {
-                Ok(stream) => stream,
-                Err(e) => {
-                    yield Err(Box::new(e));
-                    return;
-                }
-            };
-
-            while let Some(chunk) = llm_stream.next().await {
-                let chunk = match chunk {
-                    Ok(chunk) => chunk,
-                    Err(e) => {
-                        yield Err(Box::new(e));
-                        return;
-                    }
-                };
-                let chunk = ChatCompletionChunkObject::new_assistant_chunk(&chunk.content, &model);
-
-                yield Ok(chunk);
-            }
-        };
-
-        Box::pin(stream)
-    }
-
-    async fn create_chain(
-        &self,
-        agent_id: &str,
-        model: &str,
-    ) -> Result<Box<dyn Chain>, Box<dyn Error>> {
-        let chain = self
-            .agent_factory
-            .create(
-                agent_id,
-                &CreateAgentArgs {
-                    model: model.to_string(),
-                    ..CreateAgentArgs::default()
-                },
-            )
-            .await?;
-
-        Ok(chain)
-    }
-
-    fn messages_to_string(messages: &Vec<ChatCompletionMessageDto>) -> String {
-        let messages: Vec<langchain_rust::schemas::Message> =
-            messages.iter().map(|m| m.clone().into()).collect();
-
-        langchain_rust::schemas::Message::messages_to_string(&messages)
-    }
-
-    fn messages_to_input_variables(
-        messages: &Vec<ChatCompletionMessageDto>,
-    ) -> HashMap<String, serde_json::Value> {
-        let input = Self::messages_to_string(messages);
-
-        prompt_args! {
-            "input" => input,
-        }
-    }
-}
-
-impl Clone for InferenceService {
-    fn clone(&self) -> Self {
-        InferenceService {
-            agent_factory: Arc::clone(&self.agent_factory),
-        }
-    }
-}
diff --git a/core/src/chat_completion/domain/langchain_adapter.rs b/core/src/chat_completion/domain/langchain_adapter.rs
index f4b1c57..8df91eb 100644
--- a/core/src/chat_completion/domain/langchain_adapter.rs
+++ b/core/src/chat_completion/domain/langchain_adapter.rs
@@ -40,3 +40,16 @@ impl From<&ImageUrl> for langchain_rust::schemas::ImageContent {
     }
 }
+
+pub fn messages_to_langchain_messages(
+    messages: &[ChatCompletionMessageDto],
+) -> Vec<langchain_rust::schemas::Message> {
+    let messages: Vec<langchain_rust::schemas::Message> =
+        messages.iter().map(|m| m.clone().into()).collect();
+
+    messages
+}
+
+pub fn langchain_messages_to_string(messages: &[langchain_rust::schemas::Message]) -> String {
+    langchain_rust::schemas::Message::messages_to_string(messages)
+}
diff --git a/core/src/chat_completion/domain/llm_inference.rs b/core/src/chat_completion/domain/llm_inference.rs
new file mode 100644
index 0000000..66a96de
--- /dev/null
+++ b/core/src/chat_completion/domain/llm_inference.rs
@@ -0,0 +1,91 @@
+use std::sync::Arc;
+
+use crate::{
+    chat_completion::{ChatCompletionMessageDto, ChatCompletionObject},
+    llm::domain::llm_factory::{CreateLLMParameters, LLMFactory},
+};
+use anyhow::anyhow;
+use async_stream::stream;
+use futures::StreamExt;
+
+use super::{
+    inference::{Inference, InferenceArgs},
+    langchain_adapter::messages_to_langchain_messages,
+    ChatCompletionChunkObject, ChatCompletionResult, ChatCompletionStream,
+};
+
+pub struct LLMInference {
+    llm_factory: Arc<dyn LLMFactory>,
+}
+
+impl LLMInference {
+    pub fn new(llm_factory: Arc<dyn LLMFactory>) -> Self {
+        Self { llm_factory }
+    }
+}
+
+#[async_trait::async_trait]
+impl Inference for LLMInference {
+    async fn invoke(&self, args: InferenceArgs) -> ChatCompletionResult {
+        let model = &args.model;
+        let messages = messages_to_langchain_messages(&args.messages);
+
+        let llm = self
+            .llm_factory
+            .create(&CreateLLMParameters {
+                model: model.to_string(),
+                temperature: args.temperature,
+            })
+            .await?;
+
+        let result = llm.generate(&messages[..]).await.map_err(|e| anyhow!(e))?;
+
+        let message = ChatCompletionMessageDto::assistant(&result.generation);
+        let data = ChatCompletionObject::new_single_choice(message, model);
+
+        Ok(data)
+    }
+
+    async fn stream(&self, args: InferenceArgs) -> ChatCompletionStream {
+        let messages: Vec<langchain_rust::schemas::Message> =
+            args.messages.iter().map(|m| m.clone().into()).collect();
+
+        let model = args.model.to_string();
+        let llm_factory = Arc::clone(&self.llm_factory);
+        let stream = stream! {
+            let llm = match llm_factory.create(&CreateLLMParameters {
+                model: model.clone(),
+                ..CreateLLMParameters::default()
+            }).await {
+                Ok(llm) => llm,
+                Err(e) => {
+                    yield Err(e);
+                    return;
+                }
+            };
+
+            let mut llm_stream = match llm.stream(&messages[..]).await {
+                Ok(stream) => stream,
+                Err(e) => {
+                    yield Err(Box::new(e));
+                    return;
+                }
+            };
+
+            while let Some(chunk) = llm_stream.next().await {
+                let chunk = match chunk {
+                    Ok(chunk) => chunk,
+                    Err(e) => {
+                        yield Err(Box::new(e));
+                        return;
+                    }
+                };
+                let chunk = ChatCompletionChunkObject::new_assistant_chunk(&chunk.content, &model);
+
+                yield Ok(chunk);
+            }
+        };
+
+        Box::pin(stream)
+    }
+}
diff --git a/core/src/chat_completion/domain/mod.rs b/core/src/chat_completion/domain/mod.rs
index 12017d4..1eb8347 100644
--- a/core/src/chat_completion/domain/mod.rs
+++ b/core/src/chat_completion/domain/mod.rs
@@ -1,28 +1,9 @@
+pub mod agent_inference;
 pub mod dto;
-mod inference_service;
+pub mod inference;
 mod langchain_adapter;
+pub mod llm_inference;
 mod types;
 
-use std::sync::Arc;
-
 pub use dto::*;
-pub use inference_service::InferenceService;
 pub use types::*;
-
-use crate::llm::LLMDIModule;
-
-pub struct ChatCompletionDIModule {
-    llm_module: Arc<LLMDIModule>,
-}
-
-impl ChatCompletionDIModule {
-    pub fn new(llm_module: Arc<LLMDIModule>) -> Self {
-        Self { llm_module }
-    }
-
-    pub fn get_inference_factory(&self) -> Arc<InferenceService> {
-        let agent_factory = self.llm_module.get_agent_factory();
-
-        Arc::new(InferenceService::new(agent_factory))
-    }
-}
diff --git a/core/src/chat_completion/domain/types.rs b/core/src/chat_completion/domain/types.rs
index ff77c96..2e877ce 100644
--- a/core/src/chat_completion/domain/types.rs
+++ b/core/src/chat_completion/domain/types.rs
@@ -4,6 +4,6 @@ use futures::Stream;
 
 use super::dto::{ChatCompletionChunkObject, ChatCompletionObject};
 
-pub type ChatCompletionResult = Result<ChatCompletionObject, Box<dyn Error>>;
+pub type ChatCompletionResult = Result<ChatCompletionObject, Box<dyn Error + Send>>;
 pub type ChatCompletionStream =
     Pin<Box<dyn Stream<Item = Result<ChatCompletionChunkObject, Box<dyn Error + Send>>> + Send>>;
diff --git a/core/src/chat_completion/mod.rs b/core/src/chat_completion/mod.rs
index 06d902b..abdce2f 100644
--- a/core/src/chat_completion/mod.rs
+++ b/core/src/chat_completion/mod.rs
@@ -1,4 +1,23 @@
 mod domain;
+
 pub use domain::*;
+use llm_inference::LLMInference;
+use std::sync::Arc;
+
+use crate::llm::LLMDIModule;
+
+pub struct ChatCompletionDIModule {
+    llm_module: Arc<LLMDIModule>,
+}
+
+impl ChatCompletionDIModule {
+    pub fn new(llm_module: Arc<LLMDIModule>) -> Self {
+        Self { llm_module }
+    }
+
+    pub fn get_llm_inference(&self) -> Arc<LLMInference> {
+        let llm_factory = self.llm_module.get_llm_factory();
-
-pub use domain::InferenceService;
+
+        Arc::new(LLMInference::new(llm_factory))
+    }
+}
diff --git a/core/src/llm/domain/agent/agent_factory.rs b/core/src/llm/domain/agent/agent_factory.rs
index dadc270..367b155 100644
--- a/core/src/llm/domain/agent/agent_factory.rs
+++ b/core/src/llm/domain/agent/agent_factory.rs
@@ -31,7 +31,7 @@ impl AgentFactory {
     pub async fn create(
         &self,
         agent_id: &str,
-        args: &CreateAgentArgs,
+        args: CreateAgentArgs,
     ) -> Result<Box<dyn Chain>, Box<dyn Error>> {
         let chain = match self.find_factory(agent_id) {
             Some(factory) => factory.create(args).await,
diff --git a/core/src/llm/domain/agent/agent_factory_test.rs b/core/src/llm/domain/agent/agent_factory_test.rs
index 530c3cd..e32152f 100644
--- a/core/src/llm/domain/agent/agent_factory_test.rs
+++ b/core/src/llm/domain/agent/agent_factory_test.rs
@@ -55,7 +55,7 @@ mod test {
         let instance = AgentFactory::new(vec![Arc::new(factory1), Arc::new(factory2)]);
 
         let result = instance
-            .create("AAA", &CreateAgentArgs::default())
+            .create("AAA", CreateAgentArgs::default())
             .await
             .unwrap();
         let gen_result = result.call(HashMap::new()).await.unwrap();
diff --git a/core/src/llm/domain/agent/base_agent_factory.rs b/core/src/llm/domain/agent/base_agent_factory.rs
index 6915b4d..4c697e0 100644
--- a/core/src/llm/domain/agent/base_agent_factory.rs
+++ b/core/src/llm/domain/agent/base_agent_factory.rs
@@ -14,6 +14,5 @@ pub struct CreateAgentArgs {
 pub trait BaseAgentFactory: Send + Sync {
     fn is_compatible(&self, agent_id: &str) -> bool;
 
-    async fn create(&self, args: &CreateAgentArgs)
-        -> Result<Box<dyn Chain>, Box<dyn Error>>;
+    async fn create(&self, args: CreateAgentArgs) -> Result<Box<dyn Chain>, Box<dyn Error>>;
 }
diff --git a/core/src/llm/domain/agent/default_agent_factory.rs b/core/src/llm/domain/agent/default_agent_factory.rs
index a0d727b..c459de9 100644
--- a/core/src/llm/domain/agent/default_agent_factory.rs
+++ b/core/src/llm/domain/agent/default_agent_factory.rs
@@ -45,10 +45,7 @@ impl BaseAgentFactory for DefaultAgentFactory {
         true
     }
 
-    async fn create(
-        &self,
-        args: &CreateAgentArgs,
-    ) -> Result<Box<dyn Chain>, Box<dyn Error>> {
+    async fn create(&self, args: CreateAgentArgs) -> Result<Box<dyn Chain>, Box<dyn Error>> {
         let system_prompt = self.system_prompt_factory.create().await?;
         let llm = self
             .llm_factory
@@ -61,11 +58,13 @@ impl BaseAgentFactory for DefaultAgentFactory {
 
         let agent = ConversationalAgentBuilder::new()
             // .tools(&[Arc::new(command_executor)])
-            .options(Self::create_options(args))
+            .options(Self::create_options(&args))
             .prefix(system_prompt)
             .build(llm)
             .unwrap();
+
+
         let executor = AgentExecutor::from_agent(agent);
 
         Ok(Box::new(executor))
diff --git a/inference_server/src/controller/chat_completions.rs b/inference_server/src/controller/chat_completions.rs
index 18e066a..b8b274b 100644
--- a/inference_server/src/controller/chat_completions.rs
+++ b/inference_server/src/controller/chat_completions.rs
@@ -43,7 +43,7 @@ pub async fn run_chat_completions(
 
 async fn run_json_chat_completions(
     state: Arc<AppState>,
     payload: ApiChatCompletionRequestDto,
-) -> Result<Json<ChatCompletionObject>, Box<dyn Error>> {
+) -> Result<Json<ChatCompletionObject>, Box<dyn Error + Send>> {
     let result = state
         .api
         .chat_completion_invoke(&payload.model, &payload.messages)
@@ -59,7 +59,7 @@ fn run_stream_chat_completions(
     let model = payload.model.clone();
 
     Sse::new(try_stream! {
-        let mut stream = state.api.chat_completion_stream(&model, &payload.messages);
+        let mut stream = state.api.chat_completion_stream(&model, &payload.messages).await;
 
         while let Some(chunk) = stream.next().await {
             match chunk {
                 Ok(chunk) => {
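A minimal sketch of driving the new streaming contract end to end. `Inference::stream` now returns the stream itself (note the added `.await` in the SSE controller above), so failures arrive as `Err` items inside the stream rather than as a failed call. The `ChatCompletionMessageDto::user` constructor is an assumption here (only `assistant` appears in this series), and the `app_core` crate name is taken from the desktop crate's imports:

    use app_core::chat_completion::{
        inference::{Inference, InferenceArgs},
        ChatCompletionMessageDto,
    };
    use futures::StreamExt;

    // Drains any `Inference` implementation, counting chunks as they arrive.
    async fn drain(inference: &dyn Inference) -> usize {
        let args = InferenceArgs {
            model: "phi3".to_string(), // any model the LLM factory can resolve
            messages: vec![ChatCompletionMessageDto::user("Say hello")], // assumed constructor
            ..InferenceArgs::default()
        };

        let mut stream = inference.stream(args).await;
        let mut received = 0usize;
        while let Some(chunk) = stream.next().await {
            match chunk {
                Ok(_chunk) => received += 1, // each Ok item is a ChatCompletionChunkObject
                Err(e) => {
                    eprintln!("inference error: {e}");
                    break;
                }
            }
        }
        received
    }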
From 29e64fdbecaab336a89f1ee4085a065cbe3ee413 Mon Sep 17 00:00:00 2001
From: Samuel
Date: Thu, 3 Oct 2024 10:35:02 -0400
Subject: [PATCH 6/8] feat: add profile

---
 core/src/api_facade.rs                              |  4 +-
 core/src/app_container.rs                           |  5 +-
 core/src/assistant/mod.rs                           |  2 +-
 core/src/chat_completion/domain/dto/completion.rs   |  7 ++
 core/src/chat_completion/domain/inference.rs        |  2 +-
 core/src/chat_completion/domain/mod.rs              |  2 +-
 .../domain/profiled_inference_factory.rs            | 83 +++++++++++++++++++
 core/src/chat_completion/mod.rs                     | 22 ++++-
 core/src/profile/domain/mod.rs                      |  2 +
 .../domain/profile_system_prompt_factory.rs         | 29 +++++++
 core/src/profile/mod.rs                             | 10 ++-
 11 files changed, 158 insertions(+), 10 deletions(-)
 create mode 100644 core/src/chat_completion/domain/profiled_inference_factory.rs
 create mode 100644 core/src/profile/domain/profile_system_prompt_factory.rs

diff --git a/core/src/api_facade.rs b/core/src/api_facade.rs
index 260ed36..4681ffe 100644
--- a/core/src/api_facade.rs
+++ b/core/src/api_facade.rs
@@ -24,7 +24,7 @@ impl ApiFacade {
         model: &str,
         messages: &Vec<ChatCompletionMessageDto>,
     ) -> ChatCompletionResult {
-        let inference = self.container.chat_completion_module.get_llm_inference();
+        let inference = self.container.chat_completion_module.get_inference();
 
         inference
             .invoke(InferenceArgs {
@@ -40,7 +40,7 @@ impl ApiFacade {
         model: &str,
         messages: &Vec<ChatCompletionMessageDto>,
     ) -> ChatCompletionStream {
-        let inference = self.container.chat_completion_module.get_llm_inference();
+        let inference = self.container.chat_completion_module.get_inference();
 
         inference
             .stream(InferenceArgs {
diff --git a/core/src/app_container.rs b/core/src/app_container.rs
index 6da41a3..0713be4 100644
--- a/core/src/app_container.rs
+++ b/core/src/app_container.rs
@@ -27,7 +27,10 @@ impl AppContainer {
             configuration_module.clone(),
             profile_module.clone(),
         ));
-        let chat_completion_module = Arc::new(ChatCompletionDIModule::new(Arc::clone(&llm_module)));
+        let chat_completion_module = Arc::new(ChatCompletionDIModule::new(
+            Arc::clone(&llm_module),
+            Arc::clone(&profile_module),
+        ));
         let agent_module: AgentDIModule =
             AgentDIModule::new(Arc::clone(&connection), Arc::clone(&chat_completion_module));
diff --git a/core/src/assistant/mod.rs b/core/src/assistant/mod.rs
index 3e4e063..133c415 100644
--- a/core/src/assistant/mod.rs
+++ b/core/src/assistant/mod.rs
@@ -74,7 +74,7 @@ impl AgentDIModule {
     }
 
     pub fn get_thread_inference_service(&self) -> Arc<ThreadChatCompletionInference> {
-        let inference = self.chat_completion_module.get_llm_inference();
+        let inference = self.chat_completion_module.get_inference();
 
         Arc::new(ThreadChatCompletionInference::new(inference))
     }
diff --git a/core/src/chat_completion/domain/dto/completion.rs b/core/src/chat_completion/domain/dto/completion.rs
index f06f2a4..2bd9637 100644
--- a/core/src/chat_completion/domain/dto/completion.rs
+++ b/core/src/chat_completion/domain/dto/completion.rs
@@ -35,6 +35,13 @@ impl ChatCompletionMessageDto {
         }
     }
 
+    pub fn system(content: &str) -> Self {
+        ChatCompletionMessageDto {
+            role: "system".to_string(),
+            content: vec![ApiMessageContent::text(content)],
+        }
+    }
+
     pub fn to_string_content(&self) -> String {
         self.content
             .iter()
diff --git a/core/src/chat_completion/domain/inference.rs b/core/src/chat_completion/domain/inference.rs
index a4d8383..152035a 100644
--- a/core/src/chat_completion/domain/inference.rs
+++ b/core/src/chat_completion/domain/inference.rs
@@ -4,7 +4,7 @@ use crate::chat_completion::{
     ChatCompletionMessageDto, ChatCompletionResult, ChatCompletionStream,
 };
 
-#[derive(Default)]
+#[derive(Default, Clone)]
 pub struct InferenceArgs {
     pub model: String,
     pub temperature: Option<f32>,
diff --git a/core/src/chat_completion/domain/mod.rs b/core/src/chat_completion/domain/mod.rs
index 1eb8347..a65a726 100644
--- a/core/src/chat_completion/domain/mod.rs
+++ b/core/src/chat_completion/domain/mod.rs
@@ -3,7 +3,7 @@ pub mod dto;
 pub mod inference;
 mod langchain_adapter;
 pub mod llm_inference;
+pub mod profiled_inference_factory;
 mod types;
-
 pub use dto::*;
 pub use types::*;
diff --git a/core/src/chat_completion/domain/profiled_inference_factory.rs b/core/src/chat_completion/domain/profiled_inference_factory.rs
new file mode 100644
index 0000000..57c49db
--- /dev/null
+++ b/core/src/chat_completion/domain/profiled_inference_factory.rs
@@ -0,0 +1,83 @@
+use std::{error::Error, sync::Arc};
+
+use async_stream::stream;
+
+use super::{
+    inference::{Inference, InferenceArgs},
+    ChatCompletionMessageDto, ChatCompletionResult, ChatCompletionStream,
+};
+use crate::profile::domain::ProfileSystemPromptFactory;
+use futures::StreamExt;
+
+pub struct ProfiledInferenceFactory {
+    inference: Arc<dyn Inference>,
+    system_prompt_factory: Arc<ProfileSystemPromptFactory>,
+}
+
+impl ProfiledInferenceFactory {
+    pub fn new(
+        inference: Arc<dyn Inference>,
+        system_prompt_factory: Arc<ProfileSystemPromptFactory>,
+    ) -> Self {
+        Self {
+            inference,
+            system_prompt_factory,
+        }
+    }
+
+    async fn args_with_system_prompt(
+        &self,
+        args: InferenceArgs,
+    ) -> Result<InferenceArgs, Box<dyn Error + Send>> {
+        let system_prompt = self.system_prompt_factory.create_system_prompt().await?;
+        let system_message = ChatCompletionMessageDto::system(&system_prompt);
+
+        let mut messages = args.messages.clone();
+        messages.insert(0, system_message);
+
+        Ok(InferenceArgs {
+            messages,
+            ..args.clone()
+        })
+    }
+}
+
+#[async_trait::async_trait]
+impl Inference for ProfiledInferenceFactory {
+    async fn invoke(&self, args: InferenceArgs) -> ChatCompletionResult {
+        let args_with_system_prompt = self.args_with_system_prompt(args.clone()).await?;
+        self.inference.invoke(args_with_system_prompt).await
+    }
+
+    async fn stream(&self, args: InferenceArgs) -> ChatCompletionStream {
+        let inference = self.inference.clone();
+        let self_clone = self.clone();
+
+        let stream = stream! {
+            let args_with_system_prompt = match self_clone.args_with_system_prompt(args.clone()).await {
+                Ok(args_with_system_prompt) => args_with_system_prompt,
+                Err(e) => {
+                    yield Err(e);
+                    return;
+                }
+            };
+
+            let mut inference_stream = inference.stream(args_with_system_prompt).await;
+
+            while let Some(chunk) = inference_stream.next().await {
+                yield chunk;
+            }
+        };
+
+        Box::pin(stream)
+    }
+}
+
+impl Clone for ProfiledInferenceFactory {
+    fn clone(&self) -> Self {
+        Self {
+            inference: Arc::clone(&self.inference),
+            system_prompt_factory: Arc::clone(&self.system_prompt_factory),
+        }
+    }
+}
diff --git a/core/src/chat_completion/mod.rs b/core/src/chat_completion/mod.rs
index abdce2f..c614990 100644
--- a/core/src/chat_completion/mod.rs
+++ b/core/src/chat_completion/mod.rs
@@ -1,18 +1,24 @@
 mod domain;
 
 pub use domain::*;
+use inference::Inference;
 use llm_inference::LLMInference;
+use profiled_inference_factory::ProfiledInferenceFactory;
 use std::sync::Arc;
 
-use crate::llm::LLMDIModule;
+use crate::{llm::LLMDIModule, profile::ProfileDIModule};
 
 pub struct ChatCompletionDIModule {
     llm_module: Arc<LLMDIModule>,
+    profile_module: Arc<ProfileDIModule>,
 }
 
 impl ChatCompletionDIModule {
-    pub fn new(llm_module: Arc<LLMDIModule>) -> Self {
-        Self { llm_module }
+    pub fn new(llm_module: Arc<LLMDIModule>, profile_module: Arc<ProfileDIModule>) -> Self {
+        Self {
+            llm_module,
+            profile_module,
+        }
     }
 
     pub fn get_llm_inference(&self) -> Arc<LLMInference> {
@@ -20,4 +26,14 @@
 
         Arc::new(LLMInference::new(llm_factory))
     }
+
+    pub fn get_inference(&self) -> Arc<dyn Inference> {
+        let llm_inference: Arc<dyn Inference> = self.get_llm_inference();
+        let system_prompt_factory = self.profile_module.get_profile_system_prompt_factory();
+
+        Arc::new(ProfiledInferenceFactory::new(
+            llm_inference,
+            system_prompt_factory,
+        ))
+    }
 }
diff --git a/core/src/profile/domain/mod.rs b/core/src/profile/domain/mod.rs
index fc04c2e..d0fe0dd 100644
--- a/core/src/profile/domain/mod.rs
+++ b/core/src/profile/domain/mod.rs
@@ -1,9 +1,11 @@
 mod computer_info_service;
 pub mod dto;
 mod profile_repository;
+mod profile_system_prompt_factory;
 mod selected_profile_service;
 mod system_prompt_builder;
 
 pub use profile_repository::ProfileRepository;
+pub use profile_system_prompt_factory::ProfileSystemPromptFactory;
 pub use selected_profile_service::SelectedProfileService;
 pub use system_prompt_builder::SystemPromptBuilder;
diff --git a/core/src/profile/domain/profile_system_prompt_factory.rs b/core/src/profile/domain/profile_system_prompt_factory.rs
new file mode 100644
index 0000000..7aa8a84
--- /dev/null
+++ b/core/src/profile/domain/profile_system_prompt_factory.rs
@@ -0,0 +1,29 @@
+use std::{error::Error, sync::Arc};
+
+use super::{SelectedProfileService, SystemPromptBuilder};
+
+pub struct ProfileSystemPromptFactory {
+    selected_profile_service: Arc<SelectedProfileService>,
+}
+
+impl ProfileSystemPromptFactory {
+    pub fn new(selected_profile_service: Arc<SelectedProfileService>) -> Self {
+        Self {
+            selected_profile_service,
+        }
+    }
+
+    pub async fn create_system_prompt(&self) -> Result<String, Box<dyn Error + Send>> {
+        let profiles = self
+            .selected_profile_service
+            .find_selected_profiles()
+            .await?;
+
+        let prompt = SystemPromptBuilder::new()
+            .with_computer_info()
+            .with_profiles(&profiles)
+            .build();
+
+        Ok(prompt)
+    }
+}
diff --git a/core/src/profile/mod.rs b/core/src/profile/mod.rs
index f607176..0aae9ed 100644
--- a/core/src/profile/mod.rs
+++ b/core/src/profile/mod.rs
@@ -1,6 +1,6 @@
 use std::sync::Arc;
 
-use domain::{ProfileRepository, SelectedProfileService};
+use domain::{ProfileRepository, ProfileSystemPromptFactory, SelectedProfileService};
 use infrastructure::SeaOrmProfileRepository;
 
 pub mod domain;
@@ -28,4 +28,12 @@ impl ProfileDIModule {
         let profile_repository = self.get_profile_repository();
         Arc::new(SelectedProfileService::new(Arc::clone(&profile_repository)))
     }
+
+    pub fn get_profile_system_prompt_factory(&self) -> Arc<ProfileSystemPromptFactory> {
+        let selected_profile_service = self.get_selected_profiles_service();
+
+        Arc::new(ProfileSystemPromptFactory::new(
+            selected_profile_service.clone(),
+        ))
+    }
 }
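`ProfiledInferenceFactory` is a decorator: it implements `Inference` itself and wraps another `Inference`, prepending a single system message built from the selected profiles before delegating. A sketch of the wiring and the resulting message layout, under the same assumptions as the note after patch 5 (`user` constructor, `app_core` crate name, module paths as declared above):

    use std::sync::Arc;

    use app_core::chat_completion::{
        inference::{Inference, InferenceArgs},
        profiled_inference_factory::ProfiledInferenceFactory,
        ChatCompletionMessageDto,
    };
    use app_core::profile::domain::ProfileSystemPromptFactory;

    async fn wired(base: Arc<dyn Inference>, prompts: Arc<ProfileSystemPromptFactory>) {
        // Same composition ChatCompletionDIModule::get_inference() performs.
        let profiled = ProfiledInferenceFactory::new(base, prompts);

        let args = InferenceArgs {
            model: "phi3".to_string(),
            messages: vec![ChatCompletionMessageDto::user("Which editor?")], // assumed constructor
            ..InferenceArgs::default()
        };

        // args_with_system_prompt() turns the message list into
        //   [system(profile prompt), user("Which editor?")]
        // before the wrapped inference ever sees it.
        let _result = profiled.invoke(args).await;
    }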
From 10f5bf2941b2861eb98fbb7b45179725b2f90ca0 Mon Sep 17 00:00:00 2001
From: Samuel
Date: Thu, 3 Oct 2024 10:39:27 -0400
Subject: [PATCH 7/8] refactor: move inference to dedicated module

---
 core/src/api_facade.rs                                    |  4 ++--
 .../domain/{ => inference}/agent_inference.rs             | 12 +++++-------
 .../domain/{ => inference}/inference.rs                   |  0
 .../domain/{ => inference}/llm_inference.rs               | 11 +++++------
 core/src/chat_completion/domain/inference/mod.rs          |  9 +++++++++
 .../{ => inference}/profiled_inference_factory.rs         |  8 ++++----
 core/src/chat_completion/domain/mod.rs                    |  3 ---
 core/src/chat_completion/mod.rs                           |  4 +---
 8 files changed, 26 insertions(+), 25 deletions(-)
 rename core/src/chat_completion/domain/{ => inference}/agent_inference.rs (91%)
 rename core/src/chat_completion/domain/{ => inference}/inference.rs (100%)
 rename core/src/chat_completion/domain/{ => inference}/llm_inference.rs (90%)
 create mode 100644 core/src/chat_completion/domain/inference/mod.rs
 rename core/src/chat_completion/domain/{ => inference}/profiled_inference_factory.rs (91%)

diff --git a/core/src/api_facade.rs b/core/src/api_facade.rs
index 4681ffe..1a0d56a 100644
--- a/core/src/api_facade.rs
+++ b/core/src/api_facade.rs
@@ -2,8 +2,8 @@ use std::sync::Arc;
 
 use crate::{
     chat_completion::{
-        inference::{Inference, InferenceArgs},
-        ChatCompletionMessageDto, ChatCompletionResult, ChatCompletionStream,
+        inference::InferenceArgs, ChatCompletionMessageDto, ChatCompletionResult,
+        ChatCompletionStream,
     },
     configuration::ConfigurationDto,
     profile::domain::dto::ProfileDto,
diff --git a/core/src/chat_completion/domain/agent_inference.rs b/core/src/chat_completion/domain/inference/agent_inference.rs
similarity index 91%
rename from core/src/chat_completion/domain/agent_inference.rs
rename to core/src/chat_completion/domain/inference/agent_inference.rs
index 8a8819b..57efb05 100644
--- a/core/src/chat_completion/domain/agent_inference.rs
+++ b/core/src/chat_completion/domain/inference/agent_inference.rs
@@ -5,17 +5,15 @@ use langchain_rust::{chain::Chain, prompt_args};
 use std::{error::Error, sync::Arc};
 
 use crate::{
-    chat_completion::domain::langchain_adapter::{
-        langchain_messages_to_string, messages_to_langchain_messages,
+    chat_completion::{
+        domain::langchain_adapter::{langchain_messages_to_string, messages_to_langchain_messages},
+        ChatCompletionChunkObject, ChatCompletionMessageDto, ChatCompletionObject,
+        ChatCompletionResult, ChatCompletionStream,
     },
     llm::domain::agent::{base_agent_factory::CreateAgentArgs, AgentFactory},
 };
 
-use super::{
-    inference::{Inference, InferenceArgs},
-    ChatCompletionChunkObject, ChatCompletionMessageDto, ChatCompletionObject,
-    ChatCompletionResult, ChatCompletionStream,
-};
+use super::inference::{Inference, InferenceArgs};
 
 pub struct AgentInference {
     agent_factory: Arc<AgentFactory>,
diff --git a/core/src/chat_completion/domain/inference.rs b/core/src/chat_completion/domain/inference/inference.rs
similarity index 100%
rename from core/src/chat_completion/domain/inference.rs
rename to core/src/chat_completion/domain/inference/inference.rs
diff --git a/core/src/chat_completion/domain/llm_inference.rs b/core/src/chat_completion/domain/inference/llm_inference.rs
similarity index 90%
rename from core/src/chat_completion/domain/llm_inference.rs
rename to core/src/chat_completion/domain/inference/llm_inference.rs
index 66a96de..d8f53b8 100644
--- a/core/src/chat_completion/domain/llm_inference.rs
+++ b/core/src/chat_completion/domain/inference/llm_inference.rs
@@ -1,18 +1,17 @@
 use std::sync::Arc;
 
 use crate::{
-    chat_completion::{ChatCompletionMessageDto, ChatCompletionObject},
+    chat_completion::{
+        domain::langchain_adapter::messages_to_langchain_messages, ChatCompletionChunkObject,
+        ChatCompletionMessageDto, ChatCompletionObject, ChatCompletionResult, ChatCompletionStream,
+    },
     llm::domain::llm_factory::{CreateLLMParameters, LLMFactory},
 };
 use anyhow::anyhow;
 use async_stream::stream;
 use futures::StreamExt;
 
-use super::{
-    inference::{Inference, InferenceArgs},
-    langchain_adapter::messages_to_langchain_messages,
-    ChatCompletionChunkObject, ChatCompletionResult, ChatCompletionStream,
-};
+use super::inference::{Inference, InferenceArgs};
 
 pub struct LLMInference {
     llm_factory: Arc<dyn LLMFactory>,
diff --git a/core/src/chat_completion/domain/inference/mod.rs b/core/src/chat_completion/domain/inference/mod.rs
new file mode 100644
index 0000000..6059f04
--- /dev/null
+++ b/core/src/chat_completion/domain/inference/mod.rs
@@ -0,0 +1,9 @@
+mod agent_inference;
+mod inference;
+mod llm_inference;
+mod profiled_inference_factory;
+
+pub use agent_inference::AgentInference;
+pub use inference::{Inference, InferenceArgs};
+pub use llm_inference::LLMInference;
+pub use profiled_inference_factory::ProfiledInferenceFactory;
diff --git a/core/src/chat_completion/domain/profiled_inference_factory.rs b/core/src/chat_completion/domain/inference/profiled_inference_factory.rs
similarity index 91%
rename from core/src/chat_completion/domain/profiled_inference_factory.rs
rename to core/src/chat_completion/domain/inference/profiled_inference_factory.rs
index 57c49db..50fe91c 100644
--- a/core/src/chat_completion/domain/profiled_inference_factory.rs
+++ b/core/src/chat_completion/domain/inference/profiled_inference_factory.rs
@@ -2,11 +2,11 @@ use std::{error::Error, sync::Arc};
 
 use async_stream::stream;
 
-use super::{
-    inference::{Inference, InferenceArgs},
-    ChatCompletionMessageDto, ChatCompletionResult, ChatCompletionStream,
+use super::inference::{Inference, InferenceArgs};
+use crate::{
+    chat_completion::{ChatCompletionMessageDto, ChatCompletionResult, ChatCompletionStream},
+    profile::domain::ProfileSystemPromptFactory,
 };
-use crate::profile::domain::ProfileSystemPromptFactory;
 use futures::StreamExt;
 
 pub struct ProfiledInferenceFactory {
diff --git a/core/src/chat_completion/domain/mod.rs b/core/src/chat_completion/domain/mod.rs
index a65a726..7587f33 100644
--- a/core/src/chat_completion/domain/mod.rs
+++ b/core/src/chat_completion/domain/mod.rs
@@ -1,9 +1,6 @@
-pub mod agent_inference;
 pub mod dto;
 pub mod inference;
 mod langchain_adapter;
-pub mod llm_inference;
-pub mod profiled_inference_factory;
 mod types;
 pub use dto::*;
 pub use types::*;
diff --git a/core/src/chat_completion/mod.rs b/core/src/chat_completion/mod.rs
index c614990..6231235 100644
--- a/core/src/chat_completion/mod.rs
+++ b/core/src/chat_completion/mod.rs
@@ -1,9 +1,7 @@
 mod domain;
 
 pub use domain::*;
-use inference::Inference;
-use llm_inference::LLMInference;
-use profiled_inference_factory::ProfiledInferenceFactory;
+use inference::{Inference, LLMInference, ProfiledInferenceFactory};
 use std::sync::Arc;
 
 use crate::{llm::LLMDIModule, profile::ProfileDIModule};
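The move is behavior-preserving: the four inference types now live under `domain/inference/`, and the new `inference/mod.rs` re-exports them, so one import path serves downstream code. Assuming the `app_core` crate name as before, call sites shift like this:

    // Before patch 7 (modules spread across chat_completion::domain):
    // use app_core::chat_completion::llm_inference::LLMInference;

    // After patch 7 (everything behind the inference module's re-exports):
    use app_core::chat_completion::inference::{
        AgentInference, Inference, InferenceArgs, LLMInference, ProfiledInferenceFactory,
    };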
From 1fea86168650d9ba324056eca39f28fe511d916a Mon Sep 17 00:00:00 2001
From: Samuel
Date: Thu, 3 Oct 2024 10:47:08 -0400
Subject: [PATCH 8/8] feat: display selected profile

---
 desktop/src/core/tauri_command.rs                        | 11 ++++++++++-
 desktop/src/main.rs                                      |  3 ++-
 .../_components/inference-selected-profiles.tsx          | 14 ++++++++++++++
 .../settings/_components/internal-info.section.tsx       |  2 ++
 webapp/src/app/settings/page.tsx                         |  6 ++----
 webapp/src/lib/core-api/tauri/index.ts                   |  3 ++-
 webapp/src/lib/core-api/tauri/profile.ts                 | 10 ++++++++++
 7 files changed, 42 insertions(+), 7 deletions(-)
 create mode 100644 webapp/src/app/settings/_components/inference-selected-profiles.tsx
 create mode 100644 webapp/src/lib/core-api/tauri/profile.ts

diff --git a/desktop/src/core/tauri_command.rs b/desktop/src/core/tauri_command.rs
index 6a8b49a..92983cc 100644
--- a/desktop/src/core/tauri_command.rs
+++ b/desktop/src/core/tauri_command.rs
@@ -1,5 +1,5 @@
 use crate::app_state::app_state::AppState;
-use app_core::configuration::ConfigurationDto;
+use app_core::{configuration::ConfigurationDto, profile::domain::dto::ProfileDto};
 use tauri::State;
 
 #[tauri::command]
@@ -44,3 +44,12 @@ pub async fn upsert_configuration(
         .await
         .map_err(|err| err.to_string())
 }
+
+#[tauri::command]
+pub async fn selected_profiles(app_state: State<'_, AppState>) -> Result<Vec<ProfileDto>, String> {
+    app_state
+        .api
+        .get_selected_profiles()
+        .await
+        .map_err(|err| err.to_string())
+}
diff --git a/desktop/src/main.rs b/desktop/src/main.rs
index 2f2740a..df53270 100644
--- a/desktop/src/main.rs
+++ b/desktop/src/main.rs
@@ -9,7 +9,7 @@ pub mod system_tray;
 use app_state::{app_state::AppState, app_state_factory};
 use core::tauri_command::{
     find_configuration, get_app_directory_path, get_inference_server_url, is_server_up,
-    upsert_configuration,
+    selected_profiles, upsert_configuration,
 };
 use log::{info, LevelFilter};
 use screencapture::tauri_command::{assert_screen_capture_permissions, capture_screen};
@@ -44,6 +44,7 @@ async fn main() {
             find_configuration,
             get_inference_server_url,
             get_app_directory_path,
+            selected_profiles,
         ])
         .build(tauri::generate_context!())
         .expect("error while building tauri application")
diff --git a/webapp/src/app/settings/_components/inference-selected-profiles.tsx b/webapp/src/app/settings/_components/inference-selected-profiles.tsx
new file mode 100644
index 0000000..895c8a1
--- /dev/null
+++ b/webapp/src/app/settings/_components/inference-selected-profiles.tsx
@@ -0,0 +1,14 @@
+import { ReadonlyKV } from '@/components/readonly-kv';
+import { getSelectedProfiles } from '@/lib/core-api/tauri';
+import { useAsync } from 'react-use';
+
+
+
+export default function InferenceSelectedProfiles() {
+  const { value } = useAsync(() => getSelectedProfiles());
+
+  const profiles = value?.map((profile) => profile.name).join(', ');
+  return (
+    <ReadonlyKV>{profiles}</ReadonlyKV>
+  );
+}
diff --git a/webapp/src/app/settings/_components/internal-info.section.tsx b/webapp/src/app/settings/_components/internal-info.section.tsx
index da6b56d..2f0adf8 100644
--- a/webapp/src/app/settings/_components/internal-info.section.tsx
+++ b/webapp/src/app/settings/_components/internal-info.section.tsx
@@ -2,6 +2,7 @@ import { Section } from '@/components/section';
 import InferenceServerUrl from './inference-server-url';
 import InferenceServerStatus from './inference-server-status';
 import AppDataDirectory from './app-data-directory';
+import InferenceSelectedProfiles from './inference-selected-profiles';
 
 export default function InferenceServerSection() {
@@ -10,6 +11,7 @@ export default function InferenceServerSection() {
     <Section>
       <InferenceServerUrl />
       <InferenceServerStatus />
+      <InferenceSelectedProfiles />
       <AppDataDirectory />
     </Section>
   );
diff --git a/webapp/src/app/settings/page.tsx b/webapp/src/app/settings/page.tsx
index e55c6cd..f3c2319 100644
--- a/webapp/src/app/settings/page.tsx
+++ b/webapp/src/app/settings/page.tsx
@@ -11,10 +11,8 @@ export default function Settings() {
   return (
-
-
-      { process.env.NODE_ENV === 'development' && isInDesktopApp && }
-
+
+      { process.env.NODE_ENV === 'development' && isInDesktopApp && }
   );
 }
diff --git a/webapp/src/lib/core-api/tauri/index.ts b/webapp/src/lib/core-api/tauri/index.ts
index 6758c75..e8aa61f 100644
--- a/webapp/src/lib/core-api/tauri/index.ts
+++ b/webapp/src/lib/core-api/tauri/index.ts
@@ -1,4 +1,5 @@
 export * from './screen-capture';
 export * from './configuration';
 export * from './inference-server';
-export * from './info';
\ No newline at end of file
+export * from './info';
+export * from './profile';
\ No newline at end of file
diff --git a/webapp/src/lib/core-api/tauri/profile.ts b/webapp/src/lib/core-api/tauri/profile.ts
new file mode 100644
index 0000000..83144b0
--- /dev/null
+++ b/webapp/src/lib/core-api/tauri/profile.ts
@@ -0,0 +1,10 @@
+import { invoke } from '@tauri-apps/api/tauri';
+
+interface ProfileDto {
+  name: string;
+  prompt: string;
+}
+
+export async function getSelectedProfiles(): Promise<ProfileDto[]> {
+  return await invoke('selected_profiles');
+}
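For completeness, the Rust-side analogue of the webapp call above: the `selected_profiles` Tauri command is a thin wrapper over the facade, so the same data is reachable directly from Rust. A sketch assuming `ApiFacade` is exported at the `app_core` crate root and that `ProfileDto` exposes public `name` and `prompt` fields, mirroring the TypeScript interface:

    async fn print_selected_profiles(api: &app_core::ApiFacade) -> Result<(), String> {
        // Same call the `selected_profiles` command wraps, with the same error mapping.
        let profiles = api.get_selected_profiles().await.map_err(|e| e.to_string())?;

        for profile in profiles {
            // ProfileDto mirrors the TypeScript interface: { name, prompt }.
            println!("{} ({} prompt chars)", profile.name, profile.prompt.len());
        }
        Ok(())
    }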