-
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
16 changed files
with
135 additions
and
127 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
from typing import Optional | ||
from .assistant_prompt_builder import AssistantPromptBuilder | ||
from langchain_core.language_models import BaseChatModel | ||
from langchain_core.tools import BaseTool | ||
from langgraph.prebuilt import create_react_agent | ||
|
||
|
||
class AgentBuilder: | ||
def __init__( | ||
self, llm: BaseChatModel, tools: Optional[list[BaseTool]] = None | ||
) -> None: | ||
self.llm = llm | ||
self.tools = tools or [] | ||
|
||
def build(self): | ||
prompt_builder = AssistantPromptBuilder(person_name="Samuel Magny") | ||
system_prompt = prompt_builder.build_system_prompt() | ||
|
||
return create_react_agent( | ||
self.llm, tools=self.tools, messages_modifier=system_prompt | ||
) |
11 changes: 8 additions & 3 deletions
11
core/ai_assistant_core/assistant/domain/assistant_agent_factory.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
36 changes: 0 additions & 36 deletions
36
core/ai_assistant_core/assistant/domain/assistant_builder.py
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from .sqlalchemy_module import SqlAlchemyModule | ||
|
||
__all__ = [ | ||
"SqlAlchemyModule", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from .module import LLMModule | ||
|
||
__all__ = [ | ||
"LLMModule", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,5 @@ | ||
from .magic_number.magic_number import magic_number_tool | ||
from .web_search.web_search import web_search_tool | ||
from .image_generation.dalle import dall_e_tool | ||
from .tokenizer.tiktoken_tool import token_size_tool | ||
from .webpage_loader.url_content_loader import url_content_loader_tool | ||
|
||
from .module import ToolsModule | ||
|
||
__all__ = [ | ||
"magic_number_tool", | ||
"web_search_tool", | ||
"dall_e_tool", | ||
"token_size_tool", | ||
"url_content_loader_tool", | ||
"ToolsModule", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,21 +1,22 @@ | ||
import os | ||
from typing import Union | ||
from langchain_core.tools import tool | ||
from injector import inject | ||
from langchain_community.utilities.dalle_image_generator import DallEAPIWrapper | ||
from langchain_core.tools import BaseTool, StructuredTool | ||
|
||
from ai_assistant_core.llm.domain.api_key_service import ApiKeyService | ||
|
||
dall_e_client: Union[DallEAPIWrapper, None] = None | ||
|
||
@inject | ||
class DallEToolFactory: | ||
def __init__(self, api_key_service: ApiKeyService) -> None: | ||
api_key = api_key_service.get_openai_api_key() | ||
self.dalle_client = DallEAPIWrapper(api_key=api_key, model="dall-e-3") | ||
|
||
def _get_dall_e_client(): | ||
global dall_e_client | ||
if dall_e_client is None: | ||
dall_e_client = DallEAPIWrapper(openai_api_key=os.getenv("OPENAI_API_KEY")) | ||
return dall_e_client | ||
def generate_image(self, query: str) -> str: | ||
return self.dalle_client.run(query) | ||
|
||
|
||
@tool | ||
def dall_e_tool(query: str) -> str: | ||
"""Generate an image using DALL-E. Returns image URL""" | ||
dall_e = _get_dall_e_client() | ||
return dall_e.run(query) | ||
def create(self) -> BaseTool: | ||
return StructuredTool.from_function( | ||
func=self.generate_image, | ||
name="dall_e", | ||
description="Generate an image using DALL-E.", | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
from injector import Module, multiprovider | ||
from langchain_core.tools import BaseTool | ||
|
||
from ai_assistant_core.tools.image_generation.dalle import DallEToolFactory | ||
from ai_assistant_core.tools.magic_number import ( | ||
magic_number_tool, | ||
) | ||
from ai_assistant_core.tools.web_search.web_search import WebSearchToolFactory | ||
from ai_assistant_core.tools.webpage_loader.url_content_loader import ( | ||
url_content_loader_tool, | ||
) | ||
|
||
|
||
class ToolsModule(Module): | ||
@multiprovider | ||
def provide_tools( | ||
self, | ||
web_search_tool_factory: WebSearchToolFactory, | ||
dalle_tool_factory: DallEToolFactory, | ||
) -> list[BaseTool]: | ||
return [ | ||
magic_number_tool, | ||
url_content_loader_tool, | ||
web_search_tool_factory.create(), | ||
dalle_tool_factory.create(), | ||
] |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,22 +1,32 @@ | ||
import os | ||
from typing import Union | ||
from langchain_core.tools import tool | ||
from injector import inject | ||
|
||
from langchain.pydantic_v1 import BaseModel, Field | ||
from langchain_core.tools import BaseTool, StructuredTool | ||
from langchain_community.utilities import GoogleSerperAPIWrapper | ||
|
||
serper_client: Union[GoogleSerperAPIWrapper, None] = None | ||
from ai_assistant_core.llm.domain.api_key_service import ApiKeyService | ||
|
||
|
||
def _get_serper_client(): | ||
global serper_client | ||
if serper_client is None: | ||
serper_client = GoogleSerperAPIWrapper( | ||
serper_api_key=os.getenv("SERPER_API_KEY") | ||
) | ||
return serper_client | ||
class WebSearchInput(BaseModel): | ||
query: str = Field(description="query to search") | ||
|
||
|
||
@inject | ||
class WebSearchToolFactory: | ||
def __init__(self, api_key_service: ApiKeyService) -> None: | ||
api_key = api_key_service.get("SERPER_API_KEY") | ||
self.serper_client = GoogleSerperAPIWrapper(serper_api_key=api_key) | ||
|
||
@tool | ||
def web_search_tool(query: str) -> str: | ||
"""Search the web (internet) about a given topic""" | ||
serper_client = _get_serper_client() | ||
return serper_client.run(query) | ||
def search_web(self, query: str) -> str: | ||
return self.serper_client.run(query) | ||
|
||
def asearch_web(self, query: str) -> str: | ||
return self.serper_client.arun(query) | ||
|
||
def create(self) -> BaseTool: | ||
return StructuredTool.from_function( | ||
func=self.search_web, | ||
coroutine=self.asearch_web, | ||
name="search_web", | ||
description="useful for when you need to answer questions about current events", | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters