Skip to content

Commit

Permalink
feat: injected tools
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelint committed Jul 2, 2024
1 parent c43cef1 commit 457fa0f
Show file tree
Hide file tree
Showing 16 changed files with 135 additions and 127 deletions.
21 changes: 21 additions & 0 deletions core/ai_assistant_core/assistant/domain/agent_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from typing import Optional
from .assistant_prompt_builder import AssistantPromptBuilder
from langchain_core.language_models import BaseChatModel
from langchain_core.tools import BaseTool
from langgraph.prebuilt import create_react_agent


class AgentBuilder:
def __init__(
self, llm: BaseChatModel, tools: Optional[list[BaseTool]] = None
) -> None:
self.llm = llm
self.tools = tools or []

def build(self):
prompt_builder = AssistantPromptBuilder(person_name="Samuel Magny")
system_prompt = prompt_builder.build_system_prompt()

return create_react_agent(
self.llm, tools=self.tools, messages_modifier=system_prompt
)
Original file line number Diff line number Diff line change
@@ -1,19 +1,24 @@
from typing import Optional
from injector import inject
from .assistant_builder import AssistantBuilder
from .agent_builder import AgentBuilder
from ai_assistant_core.llm.domain.llm_factory import LLMFactory
from langchain_openai_api_bridge.core.agent_factory import AgentFactory, CreateAgentDto
from langgraph.graph.graph import CompiledGraph
from langchain_core.language_models import BaseChatModel
from langchain_core.tools import BaseTool


class AssistantAgentFactory(AgentFactory):

@inject
def __init__(self, llm_factory: LLMFactory) -> None:
def __init__(
self, llm_factory: LLMFactory, tools: Optional[list[BaseTool]] = None
) -> None:
self.llm_factory = llm_factory
self.tools = tools

def create_agent(self, llm: BaseChatModel, dto: CreateAgentDto) -> CompiledGraph:
return AssistantBuilder(llm).build()
return AgentBuilder(llm=llm, tools=self.tools).build()

def create_llm(self, dto: CreateAgentDto) -> CompiledGraph:
return self.llm_factory.create_chat_model(
Expand Down
36 changes: 0 additions & 36 deletions core/ai_assistant_core/assistant/domain/assistant_builder.py

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
# flake8: noqa
from langchain_core.prompts import PipelinePromptTemplate, PromptTemplate
from langchain.output_parsers import ResponseSchema, StructuredOutputParser
from ai_assistant_core.prompts.confidence_level import (
confidence_percent_response_schema,
)
from langchain.output_parsers import ResponseSchema


role_template = """\
#role
Expand Down Expand Up @@ -33,20 +31,13 @@


class AssistantPromptBuilder:
def __init__(self, person_name=None, confidence_percentage=True):
def __init__(
self,
person_name=None,
):
self.person_name = person_name
self.confidence_percentage = confidence_percentage

def build_output_parser(self) -> StructuredOutputParser:
response_schemas = [text_response]
if self.confidence_percentage:
response_schemas.append(confidence_percent_response_schema)

return StructuredOutputParser.from_response_schemas(response_schemas)

def build_system_prompt(self) -> str:
# output_parser = self.build_output_parser()
# format_instructions = output_parser.get_format_instructions()
system_prompt = PromptTemplate.from_template(system_prompt_template)

pipeline_prompt_template = PipelinePromptTemplate(
Expand Down
5 changes: 5 additions & 0 deletions core/ai_assistant_core/infrastructure/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .sqlalchemy_module import SqlAlchemyModule

__all__ = [
"SqlAlchemyModule",
]
5 changes: 5 additions & 0 deletions core/ai_assistant_core/llm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .module import LLMModule

__all__ = [
"LLMModule",
]
7 changes: 5 additions & 2 deletions core/ai_assistant_core/llm/domain/api_key_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@ def __init__(self, configuration_repo: ConfigurationRepository):
self.configuration_repo = configuration_repo

def get_openai_api_key(self) -> str:
return self.configuration_repo.get("OPENAI_API_KEY").value
return self.get("OPENAI_API_KEY")

def get_anthropic_api_key(self) -> str:
return self.configuration_repo.get("ANTHROPIC_API_KEY").value
return self.get("ANTHROPIC_API_KEY")

def get(self, key: str) -> str:
return self.configuration_repo.get(key).value
8 changes: 5 additions & 3 deletions core/ai_assistant_core/llm/module.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from injector import Injector, Module, multiprovider
from injector import Module, multiprovider

from ai_assistant_core.llm.domain.base_llm_factory import BaseLLMFactory
from ai_assistant_core.llm.infrastructure.anthropic_llm_factory import (
Expand All @@ -9,5 +9,7 @@

class LLMModule(Module):
@multiprovider
def provide_llm_factories(self, injector: Injector) -> list[BaseLLMFactory]:
return [injector.get(OpenAILLMFactory), injector.get(AnthropicLLMFactory)]
def provide_llm_factories(
self, openai_factory: OpenAILLMFactory, anthropic_factory: AnthropicLLMFactory
) -> list[BaseLLMFactory]:
return [openai_factory, anthropic_factory]
15 changes: 9 additions & 6 deletions core/ai_assistant_core/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,17 @@
from injector import Injector
from fastapi_injector import attach_injector

from ai_assistant_core.app_configuration import (
AppConfigurationModule,
)
from ai_assistant_core.assistant import AssistantModule, bind_assistant_routes
from ai_assistant_core.configuration.module import ConfigurationModule
from ai_assistant_core.health.route import bind_health_routes
from ai_assistant_core.infrastructure.sqlalchemy_module import SqlAlchemyModule
from ai_assistant_core.configuration import configuration_kv_router
from ai_assistant_core.llm.module import LLMModule

from ai_assistant_core.infrastructure import SqlAlchemyModule
from ai_assistant_core.app_configuration import (
AppConfigurationModule,
)
from ai_assistant_core.configuration import ConfigurationModule
from ai_assistant_core.llm import LLMModule
from ai_assistant_core.tools import ToolsModule


def create_app(database_url: Optional[str] = None) -> FastAPI:
Expand All @@ -24,6 +26,7 @@ def create_app(database_url: Optional[str] = None) -> FastAPI:
ConfigurationModule(),
LLMModule(),
AssistantModule(),
ToolsModule(),
SqlAlchemyModule(),
]
)
Expand Down
7 changes: 0 additions & 7 deletions core/ai_assistant_core/prompts/confidence_level.py

This file was deleted.

13 changes: 2 additions & 11 deletions core/ai_assistant_core/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,5 @@
from .magic_number.magic_number import magic_number_tool
from .web_search.web_search import web_search_tool
from .image_generation.dalle import dall_e_tool
from .tokenizer.tiktoken_tool import token_size_tool
from .webpage_loader.url_content_loader import url_content_loader_tool

from .module import ToolsModule

__all__ = [
"magic_number_tool",
"web_search_tool",
"dall_e_tool",
"token_size_tool",
"url_content_loader_tool",
"ToolsModule",
]
31 changes: 16 additions & 15 deletions core/ai_assistant_core/tools/image_generation/dalle.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
import os
from typing import Union
from langchain_core.tools import tool
from injector import inject
from langchain_community.utilities.dalle_image_generator import DallEAPIWrapper
from langchain_core.tools import BaseTool, StructuredTool

from ai_assistant_core.llm.domain.api_key_service import ApiKeyService

dall_e_client: Union[DallEAPIWrapper, None] = None

@inject
class DallEToolFactory:
def __init__(self, api_key_service: ApiKeyService) -> None:
api_key = api_key_service.get_openai_api_key()
self.dalle_client = DallEAPIWrapper(api_key=api_key, model="dall-e-3")

def _get_dall_e_client():
global dall_e_client
if dall_e_client is None:
dall_e_client = DallEAPIWrapper(openai_api_key=os.getenv("OPENAI_API_KEY"))
return dall_e_client
def generate_image(self, query: str) -> str:
return self.dalle_client.run(query)


@tool
def dall_e_tool(query: str) -> str:
"""Generate an image using DALL-E. Returns image URL"""
dall_e = _get_dall_e_client()
return dall_e.run(query)
def create(self) -> BaseTool:
return StructuredTool.from_function(
func=self.generate_image,
name="dall_e",
description="Generate an image using DALL-E.",
)
26 changes: 26 additions & 0 deletions core/ai_assistant_core/tools/module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from injector import Module, multiprovider
from langchain_core.tools import BaseTool

from ai_assistant_core.tools.image_generation.dalle import DallEToolFactory
from ai_assistant_core.tools.magic_number import (
magic_number_tool,
)
from ai_assistant_core.tools.web_search.web_search import WebSearchToolFactory
from ai_assistant_core.tools.webpage_loader.url_content_loader import (
url_content_loader_tool,
)


class ToolsModule(Module):
@multiprovider
def provide_tools(
self,
web_search_tool_factory: WebSearchToolFactory,
dalle_tool_factory: DallEToolFactory,
) -> list[BaseTool]:
return [
magic_number_tool,
url_content_loader_tool,
web_search_tool_factory.create(),
dalle_tool_factory.create(),
]
12 changes: 0 additions & 12 deletions core/ai_assistant_core/tools/tokenizer/tiktoken_tool.py

This file was deleted.

42 changes: 26 additions & 16 deletions core/ai_assistant_core/tools/web_search/web_search.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,32 @@
import os
from typing import Union
from langchain_core.tools import tool
from injector import inject

from langchain.pydantic_v1 import BaseModel, Field
from langchain_core.tools import BaseTool, StructuredTool
from langchain_community.utilities import GoogleSerperAPIWrapper

serper_client: Union[GoogleSerperAPIWrapper, None] = None
from ai_assistant_core.llm.domain.api_key_service import ApiKeyService


def _get_serper_client():
global serper_client
if serper_client is None:
serper_client = GoogleSerperAPIWrapper(
serper_api_key=os.getenv("SERPER_API_KEY")
)
return serper_client
class WebSearchInput(BaseModel):
query: str = Field(description="query to search")


@inject
class WebSearchToolFactory:
def __init__(self, api_key_service: ApiKeyService) -> None:
api_key = api_key_service.get("SERPER_API_KEY")
self.serper_client = GoogleSerperAPIWrapper(serper_api_key=api_key)

@tool
def web_search_tool(query: str) -> str:
"""Search the web (internet) about a given topic"""
serper_client = _get_serper_client()
return serper_client.run(query)
def search_web(self, query: str) -> str:
return self.serper_client.run(query)

def asearch_web(self, query: str) -> str:
return self.serper_client.arun(query)

def create(self) -> BaseTool:
return StructuredTool.from_function(
func=self.search_web,
coroutine=self.asearch_web,
name="search_web",
description="useful for when you need to answer questions about current events",
)
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

@tool("url_content_loader")
def url_content_loader_tool(url: str) -> dict:
"""Load url content"""
"""Load content from url"""

loader = RecursiveUrlLoader(
url=url,
Expand Down

0 comments on commit 457fa0f

Please sign in to comment.