diff --git a/llama_stack/providers/tests/agents/provider_config_example.yaml b/llama_stack/providers/tests/agents/provider_config_example.yaml index 5b643590c8..58f05e29ae 100644 --- a/llama_stack/providers/tests/agents/provider_config_example.yaml +++ b/llama_stack/providers/tests/agents/provider_config_example.yaml @@ -31,4 +31,4 @@ providers: persistence_store: namespace: null type: sqlite - db_path: /Users/ashwin/.llama/runtime/kvstore.db + db_path: ~/.llama/runtime/kvstore.db diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py index edcc6adeab..6774d3f1fc 100644 --- a/llama_stack/providers/tests/agents/test_agents.py +++ b/llama_stack/providers/tests/agents/test_agents.py @@ -64,6 +64,24 @@ def search_query_messages(): ] +@pytest.fixture +def attachment_message(): + return [ + UserMessage( + content="I am attaching some documentation for Torchtune. Help me answer questions I will ask next.", + ), + ] + + +@pytest.fixture +def query_attachment_messages(): + return [ + UserMessage( + content="What are the top 5 topics that were explained? Only list succinct bullet points." + ), + ] + + @pytest.mark.asyncio async def test_create_agent_turn(agents_settings, sample_messages): agents_impl = agents_settings["impl"] @@ -123,6 +141,89 @@ async def test_create_agent_turn(agents_settings, sample_messages): assert len(final_event.turn.output_message.content) > 0 +@pytest.mark.asyncio +async def test_rag_agent_as_attachments( + agents_settings, attachment_message, query_attachment_messages +): + urls = [ + "memory_optimizations.rst", + "chat.rst", + "llama3.rst", + "datasets.rst", + "qat_finetune.rst", + "lora_finetune.rst", + ] + + attachments = [ + Attachment( + content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}", + mime_type="text/plain", + ) + for i, url in enumerate(urls) + ] + + agents_impl = agents_settings["impl"] + + agent_config = AgentConfig( + model=agents_settings["common_params"]["model"], + instructions=agents_settings["common_params"]["instructions"], + enable_session_persistence=True, + sampling_params=SamplingParams(temperature=0.7, top_p=0.95), + input_shields=[], + output_shields=[], + tools=[ + MemoryToolDefinition( + memory_bank_configs=[], + query_generator_config={ + "type": "default", + "sep": " ", + }, + max_tokens_in_context=4096, + max_chunks=10, + ), + ], + max_infer_iters=5, + ) + + create_response = await agents_impl.create_agent(agent_config) + agent_id = create_response.agent_id + + # Create a session + session_create_response = await agents_impl.create_agent_session( + agent_id, "Test Session" + ) + session_id = session_create_response.session_id + + # Create and execute a turn + turn_request = dict( + agent_id=agent_id, + session_id=session_id, + messages=attachment_message, + attachments=attachments, + stream=True, + ) + + turn_response = [ + chunk async for chunk in agents_impl.create_agent_turn(**turn_request) + ] + + assert len(turn_response) > 0 + + # Create a second turn querying the agent + turn_request = dict( + agent_id=agent_id, + session_id=session_id, + messages=query_attachment_messages, + stream=True, + ) + + turn_response = [ + chunk async for chunk in agents_impl.create_agent_turn(**turn_request) + ] + + assert len(turn_response) > 0 + + @pytest.mark.asyncio async def test_create_agent_turn_with_brave_search( agents_settings, search_query_messages