Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: extensions #50

Merged
merged 7 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion core/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,5 @@ ANTHROPIC_API_KEY=""
LANGCHAIN_TRACING_V2=true
LANGCHAIN_ENDPOINT="https://api.smith.langchain.com"
LANGCHAIN_API_KEY=""
SYSTEM_FINGERPRINT="dev-local"
# Used for web search
SERPER_API_KEY=""
4 changes: 2 additions & 2 deletions core/ai_assistant_core/assistant/api/route.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from fastapi import FastAPI
from injector import Injector
from langchain_openai_api_bridge.core import AgentFactory
from langchain_openai_api_bridge.core import BaseAgentFactory
from langchain_openai_api_bridge.assistant import (
ThreadRepository,
MessageRepository,
Expand All @@ -14,7 +14,7 @@
def bind_assistant_routes(app: FastAPI, injector: Injector):

bridge = LangchainOpenaiApiBridgeFastAPI(
app=app, agent_factory_provider=lambda: injector.get(AgentFactory)
app=app, agent_factory_provider=lambda: injector.get(BaseAgentFactory)
)
bridge.bind_openai_assistant_api(
thread_repository_provider=lambda: injector.get(ThreadRepository),
Expand Down
16 changes: 0 additions & 16 deletions core/ai_assistant_core/assistant/domain/agent_factory.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,18 @@
from ai_assistant_core.assistant.domain.prompt.user_system_prompt_factory import (
UserSystemPromptFactory,
)
from ..agent_factory import BaseAgentFactory
from langchain_core.language_models import BaseChatModel
from langchain_core.tools import BaseTool
from langgraph.graph.graph import CompiledGraph
from langgraph.prebuilt import create_react_agent


@inject
class DefaultAgentFactory(BaseAgentFactory):
class DefaultAgentFactory:
def __init__(
self, tools: Optional[list[BaseTool]], prompt_factory: UserSystemPromptFactory
self,
tools: Optional[list[BaseTool]],
prompt_factory: UserSystemPromptFactory,
) -> None:
self.tools = tools
self.prompt_factory = prompt_factory
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@
ExtensionAsToolFactory,
)

from ..agent_factory import BaseAgentFactory

from langchain_core.language_models import BaseChatModel
from langchain_core.runnables import Runnable
from langgraph.prebuilt import create_react_agent


@inject
class ExtensionAgentFactory(BaseAgentFactory):
class ExtensionAgentFactory:
def __init__(
self,
extension_repository: BaseExtensionRepository,
Expand All @@ -37,10 +37,13 @@ def is_assistant_an_extension(self, assistant_id: str) -> bool:
def create(self, assistant_id: str, llm: BaseChatModel) -> Runnable:
extension_info = self.extension_repository.get_by_name(name=assistant_id)
extension = self.extension_service.load(extension=extension_info)
tool = self.extension_as_tool_factory.create(extension=extension, llm=llm)
extension_as_tool = self.extension_as_tool_factory.create(
extension=extension, llm=llm
)

return create_react_agent(
model=llm,
tools=[tool],
messages_modifier=f'No matter the input, call the tool "{tool.description}"',
tools=[extension_as_tool],
messages_modifier=f"""No matter the input, always use the following tool. If the input is not relevant, use the tool anyway with the input.
Name: "{extension_as_tool.name}". Description: "{extension_as_tool.description}".""",
)
29 changes: 22 additions & 7 deletions core/ai_assistant_core/assistant/domain/assistant_agent_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,39 +8,54 @@
ExtensionAgentFactory,
)
from ai_assistant_core.llm.domain.llm_factory import LLMFactory
from langchain_openai_api_bridge.core.agent_factory import AgentFactory, CreateAgentDto
from langgraph.graph.graph import CompiledGraph
from langchain_openai_api_bridge.core.create_agent_dto import CreateAgentDto
from langchain_openai_api_bridge.core import BaseAgentFactory
from langchain_openai_api_bridge.assistant import (
ThreadRepository,
)


@inject
class AssistantAgentFactory(AgentFactory):
class AssistantAgentFactory(BaseAgentFactory):

def __init__(
self,
llm_factory: LLMFactory,
default_agent_factory: DefaultAgentFactory,
extension_agent_factory: ExtensionAgentFactory,
thread_repository: ThreadRepository,
) -> None:
self.llm_factory = llm_factory
self.default_agent_factory = default_agent_factory
self.extension_agent_factory = extension_agent_factory
self.thread_repository = thread_repository

def create_agent(self, llm: BaseChatModel, dto: CreateAgentDto) -> Runnable:
def create_agent(self, dto: CreateAgentDto) -> Runnable:
llm = self.create_llm(dto)
factory = self.default_agent_factory.create
assistant_id = self._get_assistant_id(dto)

if self.extension_agent_factory.is_assistant_an_extension(
assistant_id=dto.assistant_id
assistant_id=assistant_id
):
factory = self.extension_agent_factory.create

return factory(
assistant_id=dto.assistant_id,
assistant_id=assistant_id,
llm=llm,
)

def create_llm(self, dto: CreateAgentDto) -> CompiledGraph:
def create_llm(self, dto: CreateAgentDto) -> BaseChatModel:
return self.llm_factory.create_chat_model(
vendor_model=dto.model,
max_tokens=dto.max_tokens,
temperature=dto.temperature,
)

def _get_assistant_id(self, dto: CreateAgentDto) -> str:
thread = self.thread_repository.retreive(thread_id=dto.thread_id)
thread_assistant_id = None
if thread is not None and thread.metadata is not None:
thread_assistant_id = thread.metadata.get("assistantId", None)

return thread_assistant_id or dto.assistant_id
4 changes: 2 additions & 2 deletions core/ai_assistant_core/assistant/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
MessageRepository,
RunRepository,
)
from langchain_openai_api_bridge.core import AgentFactory
from langchain_openai_api_bridge.core import BaseAgentFactory

from ai_assistant_core.assistant.domain.user_info_service import (
UserInfo,
Expand All @@ -27,7 +27,7 @@ def configure(self, binder: Binder):
binder.bind(ThreadRepository, to=SqlalchemyThreadRepository, scope=singleton)
binder.bind(MessageRepository, to=SqlalchemyMessageRepository, scope=singleton)
binder.bind(RunRepository, to=SqlalchemyRunRepository, scope=singleton)
binder.bind(AgentFactory, to=AssistantAgentFactory)
binder.bind(BaseAgentFactory, to=AssistantAgentFactory)

@provider
def provide_user_info(self, service: UserInfoService) -> UserInfo:
Expand Down
3 changes: 0 additions & 3 deletions core/ai_assistant_core/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
from typing import Optional
import uvicorn
import argparse
Expand All @@ -20,8 +19,6 @@
from ai_assistant_core.extension import extension_router, ExtensionModule
from ai_assistant_core.tools import ToolsModule

os.environ["LANGCHAIN_TRACING_V2"] = "false"


def create_app(database_url: Optional[str] = None) -> FastAPI:
injector = Injector(
Expand Down
8 changes: 4 additions & 4 deletions core/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ langchain-anthropic = "^0.1.13"
uvicorn = "^0.30.0"
langchain-community = "^0.2.1"
tiktoken = "^0.7.0"
langchain-openai-api-bridge = "0.9.1"
langchain-openai-api-bridge = "0.10.1"
beautifulsoup4 = "^4.12.3"
sqlalchemy = "^2.0.31"
injector = "^0.21.0"
Expand Down
Loading