Skip to content

Commit

Permalink
updates agent.py
Browse files Browse the repository at this point in the history
  • Loading branch information
BWMac committed Jan 10, 2025
1 parent 05e73f3 commit 9b6e035
Showing 1 changed file with 53 additions and 54 deletions.
107 changes: 53 additions & 54 deletions synapseclient/models/agent.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Union

from synapseclient import Synapse
from synapseclient.api import (
Expand All @@ -17,22 +17,27 @@
from synapseclient.core.async_utils import otel_trace_method


class AgentType(Enum):
class AgentType(str, Enum):
"""
Enum representing the type of agent as defined in
<https://rest-docs.synapse.org/rest/org/sagebionetworks/repo/model/agent/AgentType.html>
'BASELINE' is a default agent provided by Synapse.
'CUSTOM' is a custom agent that has been registered by a user.
- BASELINE is a default agent provided by Synapse.
- CUSTOM is a custom agent that has been registered by a user.
"""

BASELINE = "BASELINE"
CUSTOM = "CUSTOM"


class AgentSessionAccessLevel(Enum):
class AgentSessionAccessLevel(str, Enum):
"""
Enum representing the access level of the agent session as defined in
<https://rest-docs.synapse.org/rest/org/sagebionetworks/repo/model/agent/AgentAccessLevel.html>
- PUBLICLY_ACCESSIBLE: The agent can only access publicly accessible data.
- READ_YOUR_PRIVATE_DATA: The agent can read the user's private data.
- WRITE_YOUR_PRIVATE_DATA: The agent can write to the user's private data.
"""

PUBLICLY_ACCESSIBLE = "PUBLICLY_ACCESSIBLE"
Expand Down Expand Up @@ -61,9 +66,10 @@ class AgentPrompt:
"""The response from the agent."""

trace: Optional[str] = None
"""The trace or "though process" of the agent when responding to the prompt."""
"""The trace or "thought process" of the agent when responding to the prompt."""


# TODO Add example usage to the docstring
@dataclass
class AgentSession:
"""Represents a [Synapse Agent Session](https://rest-docs.synapse.org/rest/org/sagebionetworks/repo/model/agent/AgentSession.html)
Expand All @@ -77,6 +83,7 @@ class AgentSession:
modified_on: The date the agent session was last modified.
agent_registration_id: The registration ID of the agent that will be used for this session.
etag: The etag of the agent session.
"""

id: Optional[str] = None
Expand Down Expand Up @@ -129,9 +136,7 @@ def fill_from_dict(self, synapse_agent_session: Dict[str, str]) -> "AgentSession
self.etag = synapse_agent_session.get("etag", None)
return self

@otel_trace_method(
method_to_trace_name=lambda self, **kwargs: f"Start_Session: {self.id}"
)
@otel_trace_method(method_to_trace_name=lambda self, **kwargs: "Start_Session")
async def start_async(
self, *, synapse_client: Optional[Synapse] = None
) -> "AgentSession":
Expand All @@ -143,11 +148,10 @@ async def start_async(
Returns:
The new AgentSession object.
"""
syn = Synapse.get_client(synapse_client=synapse_client)
session_response = await start_session(
access_level=self.access_level.value,
access_level=self.access_level,
agent_registration_id=self.agent_registration_id,
synapse_client=syn,
synapse_client=synapse_client,
)
return self.fill_from_dict(session_response)

Expand All @@ -165,10 +169,9 @@ async def get_async(
Returns:
The retrieved AgentSession object.
"""
syn = Synapse.get_client(synapse_client=synapse_client)
session_response = await get_session(
id=self.id,
synapse_client=syn,
synapse_client=synapse_client,
)
return self.fill_from_dict(synapse_agent_session=session_response)

Expand All @@ -178,66 +181,62 @@ async def get_async(
async def update_async(
self,
*,
access_level: AgentSessionAccessLevel,
synapse_client: Optional[Synapse] = None,
) -> "AgentSession":
"""Updates an agent session. Only updates to the access level are currently supported.
"""Updates an agent session.
Only updates to the access level are currently supported.
Arguments:
synapse_client: The Synapse client to use for the request. If None, the default client will be used.
Returns:
The updated AgentSession object.
"""
syn = Synapse.get_client(synapse_client=synapse_client)

self.access_level = access_level
session_response = await update_session(
id=self.id,
access_level=self.access_level.value,
synapse_client=syn,
access_level=self.access_level,
synapse_client=synapse_client,
)
return self.fill_from_dict(session_response)

@otel_trace_method(method_to_trace_name=lambda self, **kwargs: f"Prompt: {self.id}")
async def prompt_async(
self,
*,
prompt: str,
enable_trace: bool = False,
newer_than: Optional[int] = None,
print_response: bool = False,
newer_than: Optional[int] = None,
*,
synapse_client: Optional[Synapse] = None,
) -> None:
"""Sends a prompt to the agent and adds the response to the AgentSession's chat history.
Arguments:
prompt: The prompt to send to the agent.
enable_trace: Whether to enable trace for the prompt.
print_response: Whether to print the response to the console.
newer_than: The timestamp to get trace results newer than. Defaults to None (all results).
print: Whether to print the response to the console.
synapse_client: The Synapse client to use for the request. If None, the default client will be used.
"""
syn = Synapse.get_client(synapse_client=synapse_client)
prompt_response = await send_prompt(
id=self.id,
prompt=prompt,
enable_trace=enable_trace,
synapse_client=syn,
synapse_client=synapse_client,
)
prompt_id = prompt_response["token"]

answer_response = await get_response(
prompt_id=prompt_id,
synapse_client=syn,
synapse_client=synapse_client,
)
response = answer_response["responseText"]

if enable_trace:
trace_response = await get_trace(
prompt_id=prompt_id,
newer_than=newer_than,
synapse_client=syn,
synapse_client=synapse_client,
)
trace = trace_response["page"][0]["message"]

Expand All @@ -257,6 +256,7 @@ async def prompt_async(
print(f"TRACE:\n{trace}")


# TODO Add example usage to the docstring
@dataclass
class Agent:
"""Represents a [Synapse Agent Registration](https://rest-docs.synapse.org/rest/org/sagebionetworks/repo/model/agent/AgentRegistration.html)
Expand Down Expand Up @@ -288,8 +288,8 @@ class Agent:
sessions: Dict[str, AgentSession] = field(default_factory=dict)
"""A dictionary of AgentSession objects, keyed by session ID."""

current_session: Optional[str] = None
"""The ID of the current session. Prompts will be sent to this session by default."""
current_session: Optional[AgentSession] = None
"""The current session. Prompts will be sent to this session by default."""

def fill_from_dict(self, agent_registration: Dict[str, str]) -> "Agent":
"""
Expand Down Expand Up @@ -354,10 +354,10 @@ async def get_async(self, *, synapse_client: Optional[Synapse] = None) -> "Agent
)
async def start_session_async(
self,
*,
access_level: Optional[
AgentSessionAccessLevel
] = AgentSessionAccessLevel.PUBLICLY_ACCESSIBLE,
*,
synapse_client: Optional[Synapse] = None,
) -> "AgentSession":
"""Starts an agent session.
Expand All @@ -373,12 +373,11 @@ async def start_session_async(
The new AgentSession object.
"""
access_level = AgentSessionAccessLevel(access_level)
syn = Synapse.get_client(synapse_client=synapse_client)
session = await AgentSession(
agent_registration_id=self.registration_id, access_level=access_level
).start_async(synapse_client=syn)
).start_async(synapse_client=synapse_client)
self.sessions[session.id] = session
self.current_session = session.id
self.current_session = session
return session

@otel_trace_method(
Expand All @@ -387,11 +386,12 @@ async def start_session_async(
async def get_session_async(
self, *, session_id: str, synapse_client: Optional[Synapse] = None
) -> "AgentSession":
syn = Synapse.get_client(synapse_client=synapse_client)
session = await AgentSession(id=session_id).get_async(synapse_client=syn)
session = await AgentSession(id=session_id).get_async(
synapse_client=synapse_client
)
if session.id not in self.sessions:
self.sessions[session.id] = session
self.current_session = session.id
self.current_session = session
return session

@otel_trace_method(
Expand All @@ -400,48 +400,47 @@ async def get_session_async(
async def prompt(
self,
*,
session_id: Optional[str] = None,
prompt: str,
enable_trace: bool = False,
newer_than: Optional[int] = None,
print_response: bool = False,
session: Optional[AgentSession] = None,
newer_than: Optional[int] = None,
synapse_client: Optional[Synapse] = None,
) -> None:
"""Sends a prompt to the agent for the current session.
If no session is currently active, a new session will be started.
Arguments:
session_id: The ID of the session to send the prompt to. If None, the current session will be used.
prompt: The prompt to send to the agent.
enable_trace: Whether to enable trace for the prompt.
newer_than: The timestamp to get trace results newer than. Defaults to None (all results).
print_response: Whether to print the response to the console.
session_id: The ID of the session to send the prompt to. If None, the current session will be used.
newer_than: The timestamp to get trace results newer than. Defaults to None (all results).
synapse_client: The Synapse client to use for the request. If None, the default client will be used.
"""
syn = Synapse.get_client(synapse_client=synapse_client)

# TODO: Iron this out. It's a little confusing.
if session_id:
if session_id not in self.sessions:
await self.get_session_async(session_id=session_id, synapse_client=syn)
# TODO: Iron this out. Make sure we cover all cases.
if session:
if session.id not in self.sessions:
await self.get_session_async(
session_id=session.id, synapse_client=synapse_client
)
else:
self.current_session = session_id
self.current_session = session
else:
if not self.current_session:
await self.start_session_async(synapse_client=syn)
await self.start_session_async(synapse_client=synapse_client)

await self.sessions[self.current_session].prompt_async(
await self.current_session.prompt_async(
prompt=prompt,
enable_trace=enable_trace,
newer_than=newer_than,
print_response=print_response,
synapse_client=syn,
synapse_client=synapse_client,
)

@otel_trace_method(
method_to_trace_name=lambda self, **kwargs: f"Get_Agent_Session_Chat_History: {self.registration_id}"
)
def get_chat_history(self) -> List[AgentPrompt]:
def get_chat_history(self) -> Union[List[AgentPrompt], None]:
"""Gets the chat history for the current session."""
# TODO: Is this the best way to do this?
return self.sessions[self.current_session].chat_history
return self.current_session.chat_history if self.current_session else None

0 comments on commit 9b6e035

Please sign in to comment.