Skip to content

Commit

Permalink
updates convenience functions
Browse files Browse the repository at this point in the history
  • Loading branch information
BWMac committed Jan 7, 2025
1 parent 60dc255 commit 3fd345c
Showing 1 changed file with 67 additions and 34 deletions.
101 changes: 67 additions & 34 deletions synapseclient/api/agent_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@

import json
import asyncio
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from async_lru import alru_cache
from typing import TYPE_CHECKING, Any, Dict, Optional

if TYPE_CHECKING:
from synapseclient import Synapse
Expand All @@ -15,26 +13,34 @@


async def register_agent(
request: Dict[str, Any],
cloud_agent_id: str,
cloud_alias_id: Optional[str] = None,
synapse_client: Optional["Synapse"] = None,
) -> Dict[str, Any]:
"""
Registers an agent with Synapse OR gets existing agent registration.
Arguments:
request: The request for the agent matching
<https://rest-docs.synapse.org/rest/org/sagebionetworks/repo/model/agent/AgentRegistrationRequest.html>
cloud_agent_id: The cloud provider ID of the agent to register.
cloud_alias_id: The cloud provider alias ID of the agent to register.
In the Synapse API, this defaults to 'TSTALIASID'.
synapse_client: If not passed in and caching was not disabled by
`Synapse.allow_client_caching(False)` this will use the last created
instance from the Synapse class constructor.
Returns:
The requested agent matching
The registered agent matching
<https://rest-docs.synapse.org/rest/org/sagebionetworks/repo/model/agent/AgentRegistration.html>
"""
from synapseclient import Synapse

client = Synapse.get_client(synapse_client=synapse_client)

# Request matching <https://rest-docs.synapse.org/rest/org/sagebionetworks/repo/model/agent/AgentRegistrationRequest.html>
request = {
"awsAgentId": cloud_agent_id,
"awsAliasId": cloud_alias_id if cloud_alias_id else None,
}
return await client.rest_put_async(
uri="/agent/registration", body=json.dumps(request)
)
Expand Down Expand Up @@ -63,34 +69,41 @@ async def get_agent(


async def start_session(
request: Dict[str, Any],
access_level: str,
registration_id: str,
synapse_client: Optional["Synapse"] = None,
) -> Dict[str, Any]:
"""
Starts a new chat session with an agent.
Arguments:
request: The request for the session matching
<https://rest-docs.synapse.org/rest/org/sagebionetworks/repo/model/agent/CreateAgentSessionRequest.html>
access_level: The access level of the agent.
registration_id: The ID of the agent registration to start the session for.
synapse_client: If not passed in and caching was not disabled by
`Synapse.allow_client_caching(False)` this will use the last created
instance from the Synapse class constructor.
"""
from synapseclient import Synapse

client = Synapse.get_client(synapse_client=synapse_client)

# Request matching <https://rest-docs.synapse.org/rest/org/sagebionetworks/repo/model/agent/CreateAgentSessionRequest.html>
request = {
"accessLevel": access_level,
"agentRegistrationId": registration_id,
}
return await client.rest_post_async(uri="/agent/session", body=json.dumps(request))


async def get_session(
session_id: str,
id: str,
synapse_client: Optional["Synapse"] = None,
) -> Dict[str, Any]:
"""
Gets information about an existing chat session.
Arguments:
session_id: The ID of the session to get.
id: The ID of the session to get.
synapse_client: If not passed in and caching was not disabled by
`Synapse.allow_client_caching(False)` this will use the last created
instance from the Synapse class constructor.
Expand All @@ -102,43 +115,51 @@ async def get_session(
from synapseclient import Synapse

client = Synapse.get_client(synapse_client=synapse_client)
return await client.rest_get_async(uri=f"/agent/session/{session_id}")
return await client.rest_get_async(uri=f"/agent/session/{id}")


async def update_session(
request: Dict[str, Any],
session_id: str,
id: str,
access_level: str,
synapse_client: Optional["Synapse"] = None,
) -> Dict[str, Any]:
"""
Updates the access level for a chat session.
Arguments:
request: The request for the session matching
<https://rest-docs.synapse.org/rest/org/sagebionetworks/repo/model/agent/UpdateAgentSessionRequest.html>
session_id: The ID of the session to update.
id: The ID of the session to update.
access_level: The access level of the agent.
synapse_client: If not passed in and caching was not disabled by
`Synapse.allow_client_caching(False)` this will use the last created
instance from the Synapse class constructor.
"""
from synapseclient import Synapse

client = Synapse.get_client(synapse_client=synapse_client)

# Request matching <https://rest-docs.synapse.org/rest/org/sagebionetworks/repo/model/agent/UpdateAgentSessionRequest.html>
request = {
"sessionId": id,
"accessLevel": access_level,
}
return await client.rest_put_async(
uri=f"/agent/session/{session_id}", body=json.dumps(request)
uri=f"/agent/session/{id}", body=json.dumps(request)
)


async def send_prompt(
request: Dict[str, Any],
id: str,
prompt: str,
enable_trace: bool = False,
synapse_client: Optional["Synapse"] = None,
) -> Dict[str, Any]:
"""
Sends a prompt to an agent starting an asyncronous job.
Arguments:
request: The request for the prompt matching
<https://rest-docs.synapse.org/rest/org/sagebionetworks/repo/model/agent/AgentChatRequest.html>
id: The ID of the session to send the prompt to.
prompt: The prompt to send to the agent.
enable_trace: Whether to enable trace for the prompt. Defaults to False.
synapse_client: If not passed in and caching was not disabled by
`Synapse.allow_client_caching(False)` this will use the last created
instance from the Synapse class constructor.
Expand All @@ -150,28 +171,37 @@ async def send_prompt(
from synapseclient import Synapse

client = Synapse.get_client(synapse_client=synapse_client)

# Request matching <https://rest-docs.synapse.org/rest/org/sagebionetworks/repo/model/agent/AgentChatRequest.html>
request = {
"concreteType": "org.sagebionetworks.repo.model.agent.AgentChatRequest",
"sessionId": id,
"chatText": prompt,
"enableTrace": enable_trace,
}
return await client.rest_post_async(
uri="/agent/chat/async/start", body=json.dumps(request)
)


async def get_response(
prompt_token: str,
prompt_id: str,
synapse_client: Optional["Synapse"] = None,
) -> Dict[str, Any]:
"""
Gets the response to a prompt.
Arguments:
prompt_token: The token of the prompt to get the response for.
prompt_id: The token of the prompt to get the response for.
synapse_client: If not passed in and caching was not disabled by
`Synapse.allow_client_caching(False)` this will use the last created
instance from the Synapse class constructor.
Returns:
The response matching
<https://rest-docs.synapse.org/rest/org/sagebionetworks/repo/model/agent/AgentChatResponse.html>
If the reponse is ready. Else, it will return a reponse matching
If the reponse is ready.
Else, it will return a reponse matching
<https://rest-docs.synapse.org/rest/org/sagebionetworks/repo/model/asynch/AsynchronousJobStatus.html>
Raises:
Expand All @@ -189,26 +219,23 @@ async def get_response(
f"Timeout waiting for response: {TIMEOUT} seconds"
)

response = await client.rest_get_async(
uri=f"/agent/chat/async/get/{prompt_token}"
)
response = await client.rest_get_async(uri=f"/agent/chat/async/get/{prompt_id}")
if response.get("jobState") != "PROCESSING":
return response
await asyncio.sleep(0.5)


async def get_trace(
request: Dict[str, Any],
prompt_token: str,
prompt_id: str,
newer_than: Optional[int] = None,
synapse_client: Optional["Synapse"] = None,
) -> Dict[str, Any]:
"""
Gets the trace of a prompt.
Arguments:
request: The request for the trace matching
<https://rest-docs.synapse.org/rest/org/sagebionetworks/repo/model/agent/TraceEventsRequest.html>
prompt_token: The token of the prompt to get the trace for.
prompt_id: The token of the prompt to get the trace for.
newer_than: The timestamp to get trace results newer than. Defaults to None (all results).
synapse_client: If not passed in and caching was not disabled by
`Synapse.allow_client_caching(False)` this will use the last created
instance from the Synapse class constructor.
Expand All @@ -220,6 +247,12 @@ async def get_trace(
from synapseclient import Synapse

client = Synapse.get_client(synapse_client=synapse_client)

# Request matching <https://rest-docs.synapse.org/rest/org/sagebionetworks/repo/model/agent/TraceEventsRequest.html>
request = {
"jobId": prompt_id,
"newerThanTimestamp": newer_than,
}
return await client.rest_post_async(
uri=f"/agent/chat/trace/{prompt_token}", body=json.dumps(request)
uri=f"/agent/chat/trace/{prompt_id}", body=json.dumps(request)
)

0 comments on commit 3fd345c

Please sign in to comment.