Skip to content

Commit

Permalink
[memory refactor][5/n] Migrate all vector_io providers
Browse files Browse the repository at this point in the history
  • Loading branch information
ashwinb committed Jan 22, 2025
1 parent d3c8a0e commit d3fca8e
Show file tree
Hide file tree
Showing 11 changed files with 237 additions and 347 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -413,8 +413,8 @@ async def _run(
session_info = await self.storage.get_session_info(session_id)

# if the session has a memory bank id, let the memory tool use it
if session_info.memory_bank_id:
vector_db_ids.append(session_info.memory_bank_id)
if session_info.vector_db_id:
vector_db_ids.append(session_info.vector_db_id)

yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
Expand Down Expand Up @@ -829,7 +829,7 @@ async def handle_documents(
msg = await attachment_message(self.tempdir, url_items)
input_messages.append(msg)
# Since memory is present, add all the data to the memory bank
await self.add_to_session_memory_bank(session_id, documents)
await self.add_to_session_vector_db(session_id, documents)
elif code_interpreter_tool:
# if only code_interpreter is available, we download the URLs to a tempdir
# and attach the path to them as a message to inference with the
Expand All @@ -838,7 +838,7 @@ async def handle_documents(
input_messages.append(msg)
elif memory_tool:
# if only memory is available, we load the data from the URLs and content items to the memory bank
await self.add_to_session_memory_bank(session_id, documents)
await self.add_to_session_vector_db(session_id, documents)
else:
# if no memory or code_interpreter tool is available,
# we try to load the data from the URLs and content items as a message to inference
Expand All @@ -848,31 +848,31 @@ async def handle_documents(
+ await load_data_from_urls(url_items)
)

async def _ensure_memory_bank(self, session_id: str) -> str:
async def _ensure_vector_db(self, session_id: str) -> str:
session_info = await self.storage.get_session_info(session_id)
if session_info is None:
raise ValueError(f"Session {session_id} not found")

if session_info.memory_bank_id is None:
bank_id = f"memory_bank_{session_id}"
if session_info.vector_db_id is None:
vector_db_id = f"vector_db_{session_id}"

# TODO: the semantic for registration is definitely not "creation"
# so we need to fix it if we expect the agent to create a new vector db
# for each session
await self.vector_io_api.register_vector_db(
vector_db_id=bank_id,
vector_db_id=vector_db_id,
embedding_model="all-MiniLM-L6-v2",
)
await self.storage.add_memory_bank_to_session(session_id, bank_id)
await self.storage.add_vector_db_to_session(session_id, vector_db_id)
else:
bank_id = session_info.memory_bank_id
vector_db_id = session_info.vector_db_id

return bank_id
return vector_db_id

async def add_to_session_memory_bank(
async def add_to_session_vector_db(
self, session_id: str, data: List[Document]
) -> None:
vector_db_id = await self._ensure_memory_bank(session_id)
vector_db_id = await self._ensure_vector_db(session_id)
documents = [
RAGDocument(
document_id=str(uuid.uuid4()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
class AgentSessionInfo(BaseModel):
session_id: str
session_name: str
memory_bank_id: Optional[str] = None
vector_db_id: Optional[str] = None
started_at: datetime


Expand Down Expand Up @@ -52,12 +52,12 @@ async def get_session_info(self, session_id: str) -> Optional[AgentSessionInfo]:

return AgentSessionInfo(**json.loads(value))

async def add_memory_bank_to_session(self, session_id: str, bank_id: str):
async def add_vector_db_to_session(self, session_id: str, vector_db_id: str):
session_info = await self.get_session_info(session_id)
if session_info is None:
raise ValueError(f"Session {session_id} not found")

session_info.memory_bank_id = bank_id
session_info.vector_db_id = vector_db_id
await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}",
value=session_info.model_dump_json(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,19 @@
SamplingParams,
ToolChoice,
ToolDefinition,
ToolPromptFormat,
UserMessage,
)
from llama_stack.apis.memory import MemoryBank
from llama_stack.apis.memory_banks import BankParams, VectorMemoryBank
from llama_stack.apis.safety import RunShieldResponse
from llama_stack.apis.tools import (
Tool,
ToolDef,
ToolGroup,
ToolHost,
ToolInvocationResult,
ToolPromptFormat,
)
from llama_stack.apis.vector_io import QueryChunksResponse

from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
MEMORY_QUERY_TOOL,
)
Expand Down Expand Up @@ -110,68 +110,22 @@ async def run_shield(
return RunShieldResponse(violation=None)


class MockMemoryAPI:
class MockVectorIOAPI:
def __init__(self):
self.memory_banks = {}
self.documents = {}

async def create_memory_bank(self, name, config, url=None):
bank_id = f"bank_{len(self.memory_banks)}"
bank = MemoryBank(bank_id, name, config, url)
self.memory_banks[bank_id] = bank
self.documents[bank_id] = {}
return bank

async def list_memory_banks(self):
return list(self.memory_banks.values())

async def get_memory_bank(self, bank_id):
return self.memory_banks.get(bank_id)

async def drop_memory_bank(self, bank_id):
if bank_id in self.memory_banks:
del self.memory_banks[bank_id]
del self.documents[bank_id]
return bank_id

async def insert_documents(self, bank_id, documents, ttl_seconds=None):
if bank_id not in self.documents:
raise ValueError(f"Bank {bank_id} not found")
for doc in documents:
self.documents[bank_id][doc.document_id] = doc

async def update_documents(self, bank_id, documents):
if bank_id not in self.documents:
raise ValueError(f"Bank {bank_id} not found")
for doc in documents:
if doc.document_id in self.documents[bank_id]:
self.documents[bank_id][doc.document_id] = doc

async def query_documents(self, bank_id, query, params=None):
if bank_id not in self.documents:
raise ValueError(f"Bank {bank_id} not found")
# Simple mock implementation: return all documents
chunks = [
{"content": doc.content, "token_count": 10, "document_id": doc.document_id}
for doc in self.documents[bank_id].values()
]
scores = [1.0] * len(chunks)
return {"chunks": chunks, "scores": scores}
self.chunks = {}

async def get_documents(self, bank_id, document_ids):
if bank_id not in self.documents:
raise ValueError(f"Bank {bank_id} not found")
return [
self.documents[bank_id][doc_id]
for doc_id in document_ids
if doc_id in self.documents[bank_id]
]
async def insert_chunks(self, vector_db_id, chunks, ttl_seconds=None):
    """Store *chunks* in the in-memory mock store, keyed by vector DB and document id.

    Args:
        vector_db_id: Identifier of the (mock) vector database to write to.
        chunks: Iterable of chunk objects; each must carry a ``metadata``
            dict containing a ``"document_id"`` key.
        ttl_seconds: Accepted for interface parity with the real API; unused here.
    """
    # Bug fix: the per-database dict must be created on first insert for a
    # previously-unseen vector_db_id -- the original indexed
    # self.chunks[vector_db_id] directly and raised KeyError.
    store = self.chunks.setdefault(vector_db_id, {})
    for chunk in chunks:
        # Later inserts with the same document_id overwrite earlier ones.
        store[chunk.metadata["document_id"]] = chunk

async def delete_documents(self, bank_id, document_ids):
if bank_id not in self.documents:
raise ValueError(f"Bank {bank_id} not found")
for doc_id in document_ids:
self.documents[bank_id].pop(doc_id, None)
async def query_chunks(self, vector_db_id, query, params=None):
    """Mock query: ignore *query* and *params*, return every stored chunk.

    Raises:
        ValueError: if *vector_db_id* has never been written to.
    """
    store = self.chunks.get(vector_db_id)
    if store is None:
        raise ValueError(f"Bank {vector_db_id} not found")

    retrieved = list(store.values())
    # Mock scoring: every chunk "matches" with a perfect score.
    return QueryChunksResponse(chunks=retrieved, scores=[1.0 for _ in retrieved])


class MockToolGroupsAPI:
Expand Down Expand Up @@ -241,31 +195,6 @@ async def invoke_tool(self, tool_name: str, args: dict) -> ToolInvocationResult:
return ToolInvocationResult(content={"result": "Mock tool result"})


class MockMemoryBanksAPI:
    """Stub of the memory-banks API: no persistence, canned responses only."""

    async def list_memory_banks(self) -> List[MemoryBank]:
        """There are never any registered banks."""
        return []

    async def get_memory_bank(self, memory_bank_id: str) -> Optional[MemoryBank]:
        """Lookups always miss."""
        return None

    async def register_memory_bank(
        self,
        memory_bank_id: str,
        params: BankParams,
        provider_id: Optional[str] = None,
        provider_memory_bank_id: Optional[str] = None,
    ) -> MemoryBank:
        """Fabricate a vector bank without touching any real provider."""
        # Fall back to the bank id itself when no provider-side id was given.
        if provider_memory_bank_id:
            resource_id = provider_memory_bank_id
        else:
            resource_id = memory_bank_id
        return VectorMemoryBank(
            identifier=memory_bank_id,
            provider_resource_id=resource_id,
            embedding_model="mock_model",
            chunk_size_in_tokens=512,
        )

    async def unregister_memory_bank(self, memory_bank_id: str) -> None:
        """Nothing to clean up."""
        pass


@pytest.fixture
def mock_inference_api():
    # Fresh mock inference API per test; MockInferenceAPI is presumably
    # defined earlier in this test module (not visible in this hunk).
    return MockInferenceAPI()
Expand All @@ -277,8 +206,8 @@ def mock_safety_api():


@pytest.fixture
def mock_memory_api():
return MockMemoryAPI()
def mock_vector_io_api():
return MockVectorIOAPI()


@pytest.fixture
Expand All @@ -291,17 +220,11 @@ def mock_tool_runtime_api():
return MockToolRuntimeAPI()


@pytest.fixture
def mock_memory_banks_api():
return MockMemoryBanksAPI()


@pytest.fixture
async def get_agents_impl(
mock_inference_api,
mock_safety_api,
mock_memory_api,
mock_memory_banks_api,
mock_vector_io_api,
mock_tool_runtime_api,
mock_tool_groups_api,
):
Expand All @@ -314,8 +237,7 @@ async def get_agents_impl(
),
inference_api=mock_inference_api,
safety_api=mock_safety_api,
memory_api=mock_memory_api,
memory_banks_api=mock_memory_banks_api,
vector_io_api=mock_vector_io_api,
tool_runtime_api=mock_tool_runtime_api,
tool_groups_api=mock_tool_groups_api,
)
Expand Down Expand Up @@ -484,7 +406,7 @@ async def test_chat_agent_tools(
toolgroups_for_turn=[
AgentToolGroupWithArgs(
name=MEMORY_TOOLGROUP,
args={"memory_banks": ["test_memory_bank"]},
args={"vector_dbs": ["test_vector_db"]},
)
]
)
Expand Down
Loading

0 comments on commit d3fca8e

Please sign in to comment.