From 9b6e0353558964f53a17c1e461bcd94832602c04 Mon Sep 17 00:00:00 2001 From: bwmac Date: Fri, 10 Jan 2025 10:59:51 -0500 Subject: [PATCH] updates agent.py --- synapseclient/models/agent.py | 107 +++++++++++++++++----------------- 1 file changed, 53 insertions(+), 54 deletions(-) diff --git a/synapseclient/models/agent.py b/synapseclient/models/agent.py index 0b6417a8e..a479a303e 100644 --- a/synapseclient/models/agent.py +++ b/synapseclient/models/agent.py @@ -1,7 +1,7 @@ from dataclasses import dataclass, field from datetime import datetime from enum import Enum -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union from synapseclient import Synapse from synapseclient.api import ( @@ -17,22 +17,27 @@ from synapseclient.core.async_utils import otel_trace_method -class AgentType(Enum): +class AgentType(str, Enum): """ Enum representing the type of agent as defined in - 'BASELINE' is a default agent provided by Synapse. - 'CUSTOM' is a custom agent that has been registered by a user. + + - BASELINE is a default agent provided by Synapse. + - CUSTOM is a custom agent that has been registered by a user. """ BASELINE = "BASELINE" CUSTOM = "CUSTOM" -class AgentSessionAccessLevel(Enum): +class AgentSessionAccessLevel(str, Enum): """ Enum representing the access level of the agent session as defined in + + - PUBLICLY_ACCESSIBLE: The agent can only access publicly accessible data. + - READ_YOUR_PRIVATE_DATA: The agent can read the user's private data. + - WRITE_YOUR_PRIVATE_DATA: The agent can write to the user's private data. """ PUBLICLY_ACCESSIBLE = "PUBLICLY_ACCESSIBLE" @@ -61,9 +66,10 @@ class AgentPrompt: """The response from the agent.""" trace: Optional[str] = None - """The trace or "though process" of the agent when responding to the prompt.""" + """The trace or "thought process" of the agent when responding to the prompt.""" +# TODO Add example usage to the docstring @dataclass class AgentSession: """Represents a [Synapse Agent Session](https://rest-docs.synapse.org/rest/org/sagebionetworks/repo/model/agent/AgentSession.html) @@ -77,6 +83,7 @@ class AgentSession: modified_on: The date the agent session was last modified. agent_registration_id: The registration ID of the agent that will be used for this session. etag: The etag of the agent session. + """ id: Optional[str] = None @@ -129,9 +136,7 @@ def fill_from_dict(self, synapse_agent_session: Dict[str, str]) -> "AgentSession self.etag = synapse_agent_session.get("etag", None) return self - @otel_trace_method( - method_to_trace_name=lambda self, **kwargs: f"Start_Session: {self.id}" - ) + @otel_trace_method(method_to_trace_name=lambda self, **kwargs: "Start_Session") async def start_async( self, *, synapse_client: Optional[Synapse] = None ) -> "AgentSession": @@ -143,11 +148,10 @@ async def start_async( Returns: The new AgentSession object. """ - syn = Synapse.get_client(synapse_client=synapse_client) session_response = await start_session( - access_level=self.access_level.value, + access_level=self.access_level, agent_registration_id=self.agent_registration_id, - synapse_client=syn, + synapse_client=synapse_client, ) return self.fill_from_dict(session_response) @@ -165,10 +169,9 @@ async def get_async( Returns: The retrieved AgentSession object. """ - syn = Synapse.get_client(synapse_client=synapse_client) session_response = await get_session( id=self.id, - synapse_client=syn, + synapse_client=synapse_client, ) return self.fill_from_dict(synapse_agent_session=session_response) @@ -178,10 +181,10 @@ async def get_async( async def update_async( self, *, - access_level: AgentSessionAccessLevel, synapse_client: Optional[Synapse] = None, ) -> "AgentSession": - """Updates an agent session. Only updates to the access level are currently supported. + """Updates an agent session. + Only updates to the access level are currently supported. Arguments: synapse_client: The Synapse client to use for the request. If None, the default client will be used. @@ -189,24 +192,21 @@ async def update_async( Returns: The updated AgentSession object. """ - syn = Synapse.get_client(synapse_client=synapse_client) - - self.access_level = access_level session_response = await update_session( id=self.id, - access_level=self.access_level.value, - synapse_client=syn, + access_level=self.access_level, + synapse_client=synapse_client, ) return self.fill_from_dict(session_response) @otel_trace_method(method_to_trace_name=lambda self, **kwargs: f"Prompt: {self.id}") async def prompt_async( self, - *, prompt: str, enable_trace: bool = False, - newer_than: Optional[int] = None, print_response: bool = False, + newer_than: Optional[int] = None, + *, synapse_client: Optional[Synapse] = None, ) -> None: """Sends a prompt to the agent and adds the response to the AgentSession's chat history. @@ -214,22 +214,21 @@ async def prompt_async( Arguments: prompt: The prompt to send to the agent. enable_trace: Whether to enable trace for the prompt. + print_response: Whether to print the response to the console. newer_than: The timestamp to get trace results newer than. Defaults to None (all results). - print: Whether to print the response to the console. synapse_client: The Synapse client to use for the request. If None, the default client will be used. """ - syn = Synapse.get_client(synapse_client=synapse_client) prompt_response = await send_prompt( id=self.id, prompt=prompt, enable_trace=enable_trace, - synapse_client=syn, + synapse_client=synapse_client, ) prompt_id = prompt_response["token"] answer_response = await get_response( prompt_id=prompt_id, - synapse_client=syn, + synapse_client=synapse_client, ) response = answer_response["responseText"] @@ -237,7 +236,7 @@ async def prompt_async( trace_response = await get_trace( prompt_id=prompt_id, newer_than=newer_than, - synapse_client=syn, + synapse_client=synapse_client, ) trace = trace_response["page"][0]["message"] @@ -257,6 +256,7 @@ async def prompt_async( print(f"TRACE:\n{trace}") +# TODO Add example usage to the docstring @dataclass class Agent: """Represents a [Synapse Agent Registration](https://rest-docs.synapse.org/rest/org/sagebionetworks/repo/model/agent/AgentRegistration.html) @@ -288,8 +288,8 @@ class Agent: sessions: Dict[str, AgentSession] = field(default_factory=dict) """A dictionary of AgentSession objects, keyed by session ID.""" - current_session: Optional[str] = None - """The ID of the current session. Prompts will be sent to this session by default.""" + current_session: Optional[AgentSession] = None + """The current session. Prompts will be sent to this session by default.""" def fill_from_dict(self, agent_registration: Dict[str, str]) -> "Agent": """ @@ -354,10 +354,10 @@ async def get_async(self, *, synapse_client: Optional[Synapse] = None) -> "Agent ) async def start_session_async( self, - *, access_level: Optional[ AgentSessionAccessLevel ] = AgentSessionAccessLevel.PUBLICLY_ACCESSIBLE, + *, synapse_client: Optional[Synapse] = None, ) -> "AgentSession": """Starts an agent session. @@ -373,12 +373,11 @@ async def start_session_async( The new AgentSession object. """ access_level = AgentSessionAccessLevel(access_level) - syn = Synapse.get_client(synapse_client=synapse_client) session = await AgentSession( agent_registration_id=self.registration_id, access_level=access_level - ).start_async(synapse_client=syn) + ).start_async(synapse_client=synapse_client) self.sessions[session.id] = session - self.current_session = session.id + self.current_session = session return session @otel_trace_method( @@ -387,11 +386,12 @@ async def start_session_async( async def get_session_async( self, *, session_id: str, synapse_client: Optional[Synapse] = None ) -> "AgentSession": - syn = Synapse.get_client(synapse_client=synapse_client) - session = await AgentSession(id=session_id).get_async(synapse_client=syn) + session = await AgentSession(id=session_id).get_async( + synapse_client=synapse_client + ) if session.id not in self.sessions: self.sessions[session.id] = session - self.current_session = session.id + self.current_session = session return session @otel_trace_method( @@ -400,48 +400,47 @@ async def get_session_async( async def prompt( self, *, - session_id: Optional[str] = None, prompt: str, enable_trace: bool = False, - newer_than: Optional[int] = None, print_response: bool = False, + session: Optional[AgentSession] = None, + newer_than: Optional[int] = None, synapse_client: Optional[Synapse] = None, ) -> None: """Sends a prompt to the agent for the current session. If no session is currently active, a new session will be started. Arguments: - session_id: The ID of the session to send the prompt to. If None, the current session will be used. prompt: The prompt to send to the agent. enable_trace: Whether to enable trace for the prompt. - newer_than: The timestamp to get trace results newer than. Defaults to None (all results). print_response: Whether to print the response to the console. + session_id: The ID of the session to send the prompt to. If None, the current session will be used. + newer_than: The timestamp to get trace results newer than. Defaults to None (all results). synapse_client: The Synapse client to use for the request. If None, the default client will be used. """ - syn = Synapse.get_client(synapse_client=synapse_client) - - # TODO: Iron this out. It's a little confusing. - if session_id: - if session_id not in self.sessions: - await self.get_session_async(session_id=session_id, synapse_client=syn) + # TODO: Iron this out. Make sure we cover all cases. + if session: + if session.id not in self.sessions: + await self.get_session_async( + session_id=session.id, synapse_client=synapse_client + ) else: - self.current_session = session_id + self.current_session = session else: if not self.current_session: - await self.start_session_async(synapse_client=syn) + await self.start_session_async(synapse_client=synapse_client) - await self.sessions[self.current_session].prompt_async( + await self.current_session.prompt_async( prompt=prompt, enable_trace=enable_trace, newer_than=newer_than, print_response=print_response, - synapse_client=syn, + synapse_client=synapse_client, ) @otel_trace_method( method_to_trace_name=lambda self, **kwargs: f"Get_Agent_Session_Chat_History: {self.registration_id}" ) - def get_chat_history(self) -> List[AgentPrompt]: + def get_chat_history(self) -> Union[List[AgentPrompt], None]: """Gets the chat history for the current session.""" - # TODO: Is this the best way to do this? - return self.sessions[self.current_session].chat_history + return self.current_session.chat_history if self.current_session else None