Skip to content

Commit

Permalink
updates agent_services
Browse files Browse the repository at this point in the history
  • Loading branch information
BWMac committed Jan 10, 2025
1 parent f34e9cc commit 05e73f3
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 7 deletions.
24 changes: 17 additions & 7 deletions synapseclient/api/agent_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
if TYPE_CHECKING:
from synapseclient import Synapse

from synapseclient.core.constants.concrete_types import AGENT_CHAT_REQUEST
from synapseclient.core.exceptions import SynapseTimeoutError


Expand All @@ -19,6 +20,8 @@ async def register_agent(
) -> Dict[str, Any]:
"""
Registers an agent with Synapse OR gets existing agent registration.
Sends a request matching
<https://rest-docs.synapse.org/rest/org/sagebionetworks/repo/model/agent/AgentRegistrationRequest.html>
Arguments:
cloud_agent_id: The cloud provider ID of the agent to register.
Expand All @@ -36,10 +39,9 @@ async def register_agent(

client = Synapse.get_client(synapse_client=synapse_client)

# Request matching <https://rest-docs.synapse.org/rest/org/sagebionetworks/repo/model/agent/AgentRegistrationRequest.html>
request = {
"awsAgentId": cloud_agent_id,
"awsAliasId": cloud_alias_id if cloud_alias_id else None,
"awsAliasId": cloud_alias_id if cloud_alias_id else "TSTALIASID",
}
return await client.rest_put_async(
uri="/agent/registration", body=json.dumps(request)
Expand Down Expand Up @@ -75,6 +77,8 @@ async def start_session(
) -> Dict[str, Any]:
"""
Starts a new chat session with an agent.
Sends a request matching
<https://rest-docs.synapse.org/rest/org/sagebionetworks/repo/model/agent/CreateAgentSessionRequest.html>
Arguments:
access_level: The access level of the agent.
Expand All @@ -87,7 +91,6 @@ async def start_session(

client = Synapse.get_client(synapse_client=synapse_client)

# Request matching <https://rest-docs.synapse.org/rest/org/sagebionetworks/repo/model/agent/CreateAgentSessionRequest.html>
request = {
"agentAccessLevel": access_level,
"agentRegistrationId": agent_registration_id,
Expand Down Expand Up @@ -125,6 +128,8 @@ async def update_session(
) -> Dict[str, Any]:
"""
Updates the access level for a chat session.
Sends a request matching
<https://rest-docs.synapse.org/rest/org/sagebionetworks/repo/model/agent/UpdateAgentSessionRequest.html>
Arguments:
id: The ID of the session to update.
Expand All @@ -137,7 +142,6 @@ async def update_session(

client = Synapse.get_client(synapse_client=synapse_client)

# Request matching <https://rest-docs.synapse.org/rest/org/sagebionetworks/repo/model/agent/UpdateAgentSessionRequest.html>
request = {
"sessionId": id,
"agentAccessLevel": access_level,
Expand All @@ -155,6 +159,8 @@ async def send_prompt(
) -> Dict[str, Any]:
"""
Sends a prompt to an agent starting an asyncronous job.
Sends a request matching
<https://rest-docs.synapse.org/rest/org/sagebionetworks/repo/model/agent/AgentChatRequest.html>
Arguments:
id: The ID of the session to send the prompt to.
Expand All @@ -172,9 +178,8 @@ async def send_prompt(

client = Synapse.get_client(synapse_client=synapse_client)

# Request matching <https://rest-docs.synapse.org/rest/org/sagebionetworks/repo/model/agent/AgentChatRequest.html>
request = {
"concreteType": "org.sagebionetworks.repo.model.agent.AgentChatRequest",
"concreteType": AGENT_CHAT_REQUEST,
"sessionId": id,
"chatText": prompt,
"enableTrace": enable_trace,
Expand Down Expand Up @@ -209,6 +214,8 @@ async def get_response(
from synapseclient import Synapse

client = Synapse.get_client(synapse_client=synapse_client)
# TODO: Create async compliant version of _waitForAsync and add this logic there
# synapseclient/core/async_utils.py
start_time = asyncio.get_event_loop().time()
TIMEOUT = 60

Expand All @@ -233,10 +240,14 @@ async def get_trace(
) -> Dict[str, Any]:
"""
Gets the trace of a prompt.
Sends a request matching
<https://rest-docs.synapse.org/rest/org/sagebionetworks/repo/model/agent/TraceEventsRequest.html>
Arguments:
prompt_id: The token of the prompt to get the trace for.
newer_than: The timestamp to get trace results newer than. Defaults to None (all results).
Timestamps should be in milliseconds since the epoch per the API documentation.
https://rest-docs.synapse.org/rest/org/sagebionetworks/repo/model/agent/TraceEvent.html
synapse_client: If not passed in and caching was not disabled by
`Synapse.allow_client_caching(False)` this will use the last created
instance from the Synapse class constructor.
Expand All @@ -249,7 +260,6 @@ async def get_trace(

client = Synapse.get_client(synapse_client=synapse_client)

# Request matching <https://rest-docs.synapse.org/rest/org/sagebionetworks/repo/model/agent/TraceEventsRequest.html>
request = {
"jobId": prompt_id,
"newerThanTimestamp": newer_than,
Expand Down
3 changes: 3 additions & 0 deletions synapseclient/core/constants/concrete_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,6 @@
# Activity/Provenance
USED_URL = "org.sagebionetworks.repo.model.provenance.UsedURL"
USED_ENTITY = "org.sagebionetworks.repo.model.provenance.UsedEntity"

# Agent
AGENT_CHAT_REQUEST = "org.sagebionetworks.repo.model.agent.AgentChatRequest"

0 comments on commit 05e73f3

Please sign in to comment.