diff --git a/synapseclient/api/agent_services.py b/synapseclient/api/agent_services.py index 821541a93..7d7d43838 100644 --- a/synapseclient/api/agent_services.py +++ b/synapseclient/api/agent_services.py @@ -9,6 +9,7 @@ if TYPE_CHECKING: from synapseclient import Synapse +from synapseclient.core.constants.concrete_types import AGENT_CHAT_REQUEST from synapseclient.core.exceptions import SynapseTimeoutError @@ -19,6 +20,8 @@ async def register_agent( ) -> Dict[str, Any]: """ Registers an agent with Synapse OR gets existing agent registration. + Sends a request matching + Arguments: cloud_agent_id: The cloud provider ID of the agent to register. @@ -36,10 +39,9 @@ async def register_agent( client = Synapse.get_client(synapse_client=synapse_client) - # Request matching request = { "awsAgentId": cloud_agent_id, - "awsAliasId": cloud_alias_id if cloud_alias_id else None, + "awsAliasId": cloud_alias_id if cloud_alias_id else "TSTALIASID", } return await client.rest_put_async( uri="/agent/registration", body=json.dumps(request) @@ -75,6 +77,8 @@ async def start_session( ) -> Dict[str, Any]: """ Starts a new chat session with an agent. + Sends a request matching + Arguments: access_level: The access level of the agent. @@ -87,7 +91,6 @@ async def start_session( client = Synapse.get_client(synapse_client=synapse_client) - # Request matching request = { "agentAccessLevel": access_level, "agentRegistrationId": agent_registration_id, @@ -125,6 +128,8 @@ async def update_session( ) -> Dict[str, Any]: """ Updates the access level for a chat session. + Sends a request matching + Arguments: id: The ID of the session to update. @@ -137,7 +142,6 @@ async def update_session( client = Synapse.get_client(synapse_client=synapse_client) - # Request matching request = { "sessionId": id, "agentAccessLevel": access_level, @@ -155,6 +159,8 @@ async def send_prompt( ) -> Dict[str, Any]: """ Sends a prompt to an agent starting an asyncronous job. + Sends a request matching + Arguments: id: The ID of the session to send the prompt to. @@ -172,9 +178,8 @@ async def send_prompt( client = Synapse.get_client(synapse_client=synapse_client) - # Request matching request = { - "concreteType": "org.sagebionetworks.repo.model.agent.AgentChatRequest", + "concreteType": AGENT_CHAT_REQUEST, "sessionId": id, "chatText": prompt, "enableTrace": enable_trace, @@ -209,6 +214,8 @@ async def get_response( from synapseclient import Synapse client = Synapse.get_client(synapse_client=synapse_client) + # TODO: Create async compliant version of _waitForAsync and add this logic there + # synapseclient/core/async_utils.py start_time = asyncio.get_event_loop().time() TIMEOUT = 60 @@ -233,10 +240,14 @@ async def get_trace( ) -> Dict[str, Any]: """ Gets the trace of a prompt. + Sends a request matching + Arguments: prompt_id: The token of the prompt to get the trace for. newer_than: The timestamp to get trace results newer than. Defaults to None (all results). + Timestamps should be in milliseconds since the epoch per the API documentation. + https://rest-docs.synapse.org/rest/org/sagebionetworks/repo/model/agent/TraceEvent.html synapse_client: If not passed in and caching was not disabled by `Synapse.allow_client_caching(False)` this will use the last created instance from the Synapse class constructor. @@ -249,7 +260,6 @@ async def get_trace( client = Synapse.get_client(synapse_client=synapse_client) - # Request matching request = { "jobId": prompt_id, "newerThanTimestamp": newer_than, diff --git a/synapseclient/core/constants/concrete_types.py b/synapseclient/core/constants/concrete_types.py index f8d4ee442..e2033c030 100644 --- a/synapseclient/core/constants/concrete_types.py +++ b/synapseclient/core/constants/concrete_types.py @@ -68,3 +68,6 @@ # Activity/Provenance USED_URL = "org.sagebionetworks.repo.model.provenance.UsedURL" USED_ENTITY = "org.sagebionetworks.repo.model.provenance.UsedEntity" + +# Agent +AGENT_CHAT_REQUEST = "org.sagebionetworks.repo.model.agent.AgentChatRequest"