diff --git a/app/agents/random_number.py b/app/agents/random_number.py index b718ebf..c95d5eb 100644 --- a/app/agents/random_number.py +++ b/app/agents/random_number.py @@ -34,6 +34,7 @@ class Input(BaseModel): class Output(BaseModel): output: Any + random_number_agent_executor = AgentExecutor( agent=random_number_agent, tools=tools, diff --git a/app/agents/util.py b/app/agents/util.py index 8512ccb..432fd5e 100644 --- a/app/agents/util.py +++ b/app/agents/util.py @@ -27,11 +27,13 @@ def create_ollama_functions_agent( functions=[DEFAULT_RESPONSE_FUNCTION] + [convert_to_openai_function(t) for t in tools], format="json", ) + + def agent_scratchpad(x): + return adapt_to_ollama_messages(format_to_openai_function_messages(x["intermediate_steps"])) + agent = ( RunnablePassthrough.assign( - agent_scratchpad=lambda x: adapt_to_ollama_messages(format_to_openai_function_messages( - x["intermediate_steps"] - )) + agent_scratchpad=agent_scratchpad, ) | prompt | llm_with_tools diff --git a/app/chains/supervisor.py b/app/chains/supervisor.py index 0f698a0..2e08f84 100644 --- a/app/chains/supervisor.py +++ b/app/chains/supervisor.py @@ -3,6 +3,8 @@ from app.dependencies.openai_chat_model import openai_chat_model +# https://github.com/langchain-ai/langgraph/blob/main/examples/multi_agent/agent_supervisor.ipynb + system_prompt = ( "You are a supervisor tasked with managing a conversation between the" " following workers: {members}. Given the following user request," diff --git a/app/dependencies/slack.py b/app/dependencies/slack.py index 11d8b5b..b8ccefe 100644 --- a/app/dependencies/slack.py +++ b/app/dependencies/slack.py @@ -1,12 +1,14 @@ from typing import Annotated from fastapi import Depends +from langchain_core.messages import HumanMessage from slack_bolt import App from slack_bolt.adapter.fastapi import SlackRequestHandler from slack_bolt.adapter.socket_mode import SocketModeHandler from slack_sdk import WebClient from app import config +from app.graph import graph app = App( token=config.SLACK_BOT_TOKEN, @@ -27,16 +29,26 @@ def get_slack_request_handler() -> SlackRequestHandler: SlackRequestHandlerDep = Annotated[SlackRequestHandler, Depends(get_slack_request_handler)] +def invoke_graph(body) -> str: + text = body.get('event', {}).get('text', '') + response = graph.invoke({ + "messages": [HumanMessage(content=text)], + }) + return response["messages"][-1].content + + @app.event("message") def handle_message_events(body, say, logger): + if body.get('event', {}).get('channel_type') != 'im': + return logger.info(body) - text = body.get('event', {}).get('text', '') - if 'hello' in text.lower(): - say("Hello there! :wave:") + response = invoke_graph(body) + say(response) # Define an event listener for "app_mention" events @app.event("app_mention") def handle_app_mention_events(body, say, logger): logger.info(body) - say("Hello SRE slackers! :blush:") + response = invoke_graph(body) + say(response) diff --git a/app/graph.py b/app/graph.py new file mode 100644 index 0000000..81e929e --- /dev/null +++ b/app/graph.py @@ -0,0 +1,81 @@ +import functools +import operator +from typing import Annotated, Sequence, TypedDict + +from langchain.agents import create_openai_tools_agent, AgentExecutor +from langchain_core.messages import BaseMessage, HumanMessage +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +from langchain_openai import ChatOpenAI +from langgraph.graph import StateGraph, END +from langgraph.pregel import Pregel + +from app.chains.supervisor import build_supervisor_chain +from app.dependencies.openai_chat_model import openai_chat_model +from app.tools.random_number import random_number +from app.tools.random_select import random_select + + +# https://github.com/langchain-ai/langgraph/blob/main/examples/multi_agent/agent_supervisor.ipynb + +def create_agent(llm: ChatOpenAI, tools: list, system_prompt: str): + prompt = ChatPromptTemplate.from_messages( + [ + ( + "system", + system_prompt, + ), + MessagesPlaceholder(variable_name="messages"), + MessagesPlaceholder(variable_name="agent_scratchpad"), + ] + ) + agent = create_openai_tools_agent(llm, tools, prompt) + executor = AgentExecutor(agent=agent, tools=tools) + return executor + + +def agent_node(state, agent, name): + result = agent.invoke(state) + return {"messages": [HumanMessage(content=result["output"], name=name)]} + + +class AgentState(TypedDict): + messages: Annotated[Sequence[BaseMessage], operator.add] + next: str + + +SUPERVISOR_NAME = "Supervisor" + +GRAPH = { + "RandomNumber": { + "tools": [random_number], + "system_prompt": "You are a random number generator.", + }, + "RandomSelect": { + "tools": [random_select], + "system_prompt": "You are a random selector.", + } +} + + +def build_graph() -> Pregel: + members = list(GRAPH.keys()) + supervisor_chain = build_supervisor_chain(members) + + workflow = StateGraph(AgentState) + for member, config in GRAPH.items(): + agent = create_agent(openai_chat_model, config["tools"], config["system_prompt"]) + workflow.add_node(member, functools.partial(agent_node, agent=agent, name=member)) + workflow.add_node(SUPERVISOR_NAME, supervisor_chain) + + for member in members: + workflow.add_edge(member, SUPERVISOR_NAME) + + conditional_map = {k: k for k in members} + conditional_map["FINISH"] = END + workflow.add_conditional_edges(SUPERVISOR_NAME, lambda x: x["next"], conditional_map) + workflow.set_entry_point(SUPERVISOR_NAME) + + return workflow.compile() + + +graph = build_graph() diff --git a/app/server.py b/app/server.py index d8c0a93..65ce22f 100644 --- a/app/server.py +++ b/app/server.py @@ -6,6 +6,7 @@ from app.chains.extraction import extraction_chain from app.chains.supervisor import build_supervisor_chain from app.dependencies.ollama_chat_model import ollama_chat_model +from app.graph import graph from app.routers import slack app = FastAPI() @@ -40,6 +41,12 @@ async def redirect_root_to_docs(): path="/supervisor", ) +add_routes( + app, + graph, + path="/graph", +) + app.include_router(slack.router) if __name__ == "__main__": diff --git a/app/tools/random_number.py b/app/tools/random_number.py index 9048e9b..5ae00f1 100644 --- a/app/tools/random_number.py +++ b/app/tools/random_number.py @@ -4,8 +4,11 @@ @tool() -def random_number() -> str: +def random_number( + lower: int = 0, + upper: int = 100, +) -> int: """ - Generate a random number between 0 and 100 + Generate a random number between lower (default 0) and upper (default 100) """ - return str(random.randint(0, 100)) + return random.randint(lower, upper) diff --git a/app/tools/random_select.py b/app/tools/random_select.py new file mode 100644 index 0000000..6a82ec2 --- /dev/null +++ b/app/tools/random_select.py @@ -0,0 +1,13 @@ +import random +from collections.abc import Sequence +from typing import Any + +from langchain_core.tools import tool + + +@tool() +def random_select(items: Sequence[Any], k: int) -> list[Any]: + """ + Randomly select k items from a list + """ + return random.sample(items, k)