Skip to content

Commit

Permalink
[memory refactor][6/n] Update naming and routes (#839)
Browse files Browse the repository at this point in the history
Making a few small naming changes as per feedback:

- RAGToolRuntime methods are called `insert` and `query` to keep them
more general
- The tool names are changed to non-namespaced forms
`insert_into_memory` and `query_from_memory`
- The REST endpoints are more REST-ful
  • Loading branch information
ashwinb authored Jan 22, 2025
1 parent c9e5578 commit a63a43c
Show file tree
Hide file tree
Showing 11 changed files with 240 additions and 251 deletions.
374 changes: 185 additions & 189 deletions docs/resources/llama-stack-spec.html

Large diffs are not rendered by default.

35 changes: 15 additions & 20 deletions docs/resources/llama-stack-spec.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1009,7 +1009,7 @@ components:
- vector_db_id
- chunks
type: object
InsertDocumentsRequest:
InsertRequest:
additionalProperties: false
properties:
chunk_size_in_tokens:
Expand Down Expand Up @@ -1299,10 +1299,6 @@ components:
type: string
inserted_context:
$ref: '#/components/schemas/InterleavedContent'
memory_bank_ids:
items:
type: string
type: array
started_at:
format: date-time
type: string
Expand All @@ -1314,11 +1310,13 @@ components:
type: string
turn_id:
type: string
vector_db_ids:
type: string
required:
- turn_id
- step_id
- step_type
- memory_bank_ids
- vector_db_ids
- inserted_context
type: object
Message:
Expand Down Expand Up @@ -1710,7 +1708,7 @@ components:
- gt
- lt
type: string
QueryContextRequest:
QueryRequest:
additionalProperties: false
properties:
content:
Expand All @@ -1723,7 +1721,6 @@ components:
type: array
required:
- content
- query_config
- vector_db_ids
type: object
QuerySpanTreeResponse:
Expand Down Expand Up @@ -5176,7 +5173,7 @@ paths:
description: OK
tags:
- ToolRuntime
/v1/tool-runtime/rag-tool/insert-documents:
/v1/tool-runtime/rag-tool/insert:
post:
parameters:
- description: JSON-encoded provider data which will be made available to the
Expand All @@ -5197,15 +5194,15 @@ paths:
content:
application/json:
schema:
$ref: '#/components/schemas/InsertDocumentsRequest'
$ref: '#/components/schemas/InsertRequest'
required: true
responses:
'200':
description: OK
summary: Index documents so they can be used by the RAG system
tags:
- ToolRuntime
/v1/tool-runtime/rag-tool/query-context:
/v1/tool-runtime/rag-tool/query:
post:
parameters:
- description: JSON-encoded provider data which will be made available to the
Expand All @@ -5226,7 +5223,7 @@ paths:
content:
application/json:
schema:
$ref: '#/components/schemas/QueryContextRequest'
$ref: '#/components/schemas/QueryRequest'
required: true
responses:
'200':
Expand Down Expand Up @@ -5814,9 +5811,8 @@ tags:
- description: <SchemaDefinition schemaRef="#/components/schemas/InsertChunksRequest"
/>
name: InsertChunksRequest
- description: <SchemaDefinition schemaRef="#/components/schemas/InsertDocumentsRequest"
/>
name: InsertDocumentsRequest
- description: <SchemaDefinition schemaRef="#/components/schemas/InsertRequest" />
name: InsertRequest
- name: Inspect
- description: <SchemaDefinition schemaRef="#/components/schemas/InterleavedContent"
/>
Expand Down Expand Up @@ -5943,9 +5939,8 @@ tags:
- description: <SchemaDefinition schemaRef="#/components/schemas/QueryConditionOp"
/>
name: QueryConditionOp
- description: <SchemaDefinition schemaRef="#/components/schemas/QueryContextRequest"
/>
name: QueryContextRequest
- description: <SchemaDefinition schemaRef="#/components/schemas/QueryRequest" />
name: QueryRequest
- description: <SchemaDefinition schemaRef="#/components/schemas/QuerySpanTreeResponse"
/>
name: QuerySpanTreeResponse
Expand Down Expand Up @@ -6245,7 +6240,7 @@ x-tagGroups:
- ImageDelta
- InferenceStep
- InsertChunksRequest
- InsertDocumentsRequest
- InsertRequest
- InterleavedContent
- InterleavedContentItem
- InvokeToolRequest
Expand Down Expand Up @@ -6290,7 +6285,7 @@ x-tagGroups:
- QueryChunksResponse
- QueryCondition
- QueryConditionOp
- QueryContextRequest
- QueryRequest
- QuerySpanTreeResponse
- QuerySpansResponse
- QueryTracesResponse
Expand Down
10 changes: 5 additions & 5 deletions llama_stack/apis/tools/rag_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ class RAGQueryConfig(BaseModel):
@runtime_checkable
@trace_protocol
class RAGToolRuntime(Protocol):
@webmethod(route="/tool-runtime/rag-tool/insert-documents", method="POST")
async def insert_documents(
@webmethod(route="/tool-runtime/rag-tool/insert", method="POST")
async def insert(
self,
documents: List[RAGDocument],
vector_db_id: str,
Expand All @@ -84,12 +84,12 @@ async def insert_documents(
"""Index documents so they can be used by the RAG system"""
...

@webmethod(route="/tool-runtime/rag-tool/query-context", method="POST")
async def query_context(
@webmethod(route="/tool-runtime/rag-tool/query", method="POST")
async def query(
self,
content: InterleavedContent,
query_config: RAGQueryConfig,
vector_db_ids: List[str],
query_config: Optional[RAGQueryConfig] = None,
) -> RAGQueryResult:
"""Query the RAG system for context; typically invoked by the agent"""
...
2 changes: 1 addition & 1 deletion llama_stack/apis/vector_io/vector_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def get_vector_db(self, vector_db_id: str) -> Optional[VectorDB]: ...
class VectorIO(Protocol):
vector_db_store: VectorDBStore

# this will just block now until documents are inserted, but it should
# this will just block now until chunks are inserted, but it should
# probably return a Job instance which can be polled for completion
@webmethod(route="/vector-io/insert", method="POST")
async def insert_chunks(
Expand Down
19 changes: 9 additions & 10 deletions llama_stack/distribution/routers/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,25 +414,25 @@ def __init__(
) -> None:
self.routing_table = routing_table

async def query_context(
async def query(
self,
content: InterleavedContent,
query_config: RAGQueryConfig,
vector_db_ids: List[str],
query_config: Optional[RAGQueryConfig] = None,
) -> RAGQueryResult:
return await self.routing_table.get_provider_impl(
"rag_tool.query_context"
).query_context(content, query_config, vector_db_ids)
"query_from_memory"
).query(content, vector_db_ids, query_config)

async def insert_documents(
async def insert(
self,
documents: List[RAGDocument],
vector_db_id: str,
chunk_size_in_tokens: int = 512,
) -> None:
return await self.routing_table.get_provider_impl(
"rag_tool.insert_documents"
).insert_documents(documents, vector_db_id, chunk_size_in_tokens)
"insert_into_memory"
).insert(documents, vector_db_id, chunk_size_in_tokens)

def __init__(
self,
Expand All @@ -441,10 +441,9 @@ def __init__(
self.routing_table = routing_table

# HACK ALERT this should be in sync with "get_all_api_endpoints()"
# TODO: make sure rag_tool vs builtin::memory is correct everywhere
self.rag_tool = self.RagToolImpl(routing_table)
setattr(self, "rag_tool.query_context", self.rag_tool.query_context)
setattr(self, "rag_tool.insert_documents", self.rag_tool.insert_documents)
for method in ("query", "insert"):
setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))

async def initialize(self) -> None:
pass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def make_random_string(length: int = 8):


TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
MEMORY_QUERY_TOOL = "rag_tool.query_context"
MEMORY_QUERY_TOOL = "query_from_memory"
WEB_SEARCH_TOOL = "web_search"
MEMORY_GROUP = "builtin::memory"

Expand Down Expand Up @@ -432,16 +432,16 @@ async def _run(
)
)
)
result = await self.tool_runtime_api.rag_tool.query_context(
result = await self.tool_runtime_api.rag_tool.query(
content=concat_interleaved_content(
[msg.content for msg in input_messages]
),
vector_db_ids=vector_db_ids,
query_config=RAGQueryConfig(
query_generator_config=DefaultRAGQueryGeneratorConfig(),
max_tokens_in_context=4096,
max_chunks=5,
),
vector_db_ids=vector_db_ids,
)
retrieved_context = result.content

Expand Down Expand Up @@ -882,7 +882,7 @@ async def add_to_session_vector_db(
)
for a in data
]
await self.tool_runtime_api.rag_tool.insert_documents(
await self.tool_runtime_api.rag_tool.insert(
documents=documents,
vector_db_id=vector_db_id,
chunk_size_in_tokens=512,
Expand Down
11 changes: 6 additions & 5 deletions llama_stack/providers/inline/tool_runtime/memory/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ async def initialize(self):
async def shutdown(self):
pass

async def insert_documents(
async def insert(
self,
documents: List[RAGDocument],
vector_db_id: str,
Expand All @@ -87,15 +87,16 @@ async def insert_documents(
vector_db_id=vector_db_id,
)

async def query_context(
async def query(
self,
content: InterleavedContent,
query_config: RAGQueryConfig,
vector_db_ids: List[str],
query_config: Optional[RAGQueryConfig] = None,
) -> RAGQueryResult:
if not vector_db_ids:
return RAGQueryResult(content=None)

query_config = query_config or RAGQueryConfig()
query = await generate_rag_query(
query_config.query_generator_config,
content,
Expand Down Expand Up @@ -159,11 +160,11 @@ async def list_runtime_tools(
# encountering fatals.
return [
ToolDef(
name="rag_tool.query_context",
name="query_from_memory",
description="Retrieve context from memory",
),
ToolDef(
name="rag_tool.insert_documents",
name="insert_into_memory",
description="Insert documents into memory",
),
]
Expand Down
4 changes: 2 additions & 2 deletions llama_stack/providers/tests/tools/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,14 @@ async def test_rag_tool(self, tools_stack, sample_documents):
)

# Insert documents into memory
await tools_impl.rag_tool.insert_documents(
await tools_impl.rag_tool.insert(
documents=sample_documents,
vector_db_id="test_bank",
chunk_size_in_tokens=512,
)

# Execute the memory tool
response = await tools_impl.rag_tool.query_context(
response = await tools_impl.rag_tool.query(
content="What are the main topics covered in the documentation?",
vector_db_ids=["test_bank"],
)
Expand Down
20 changes: 9 additions & 11 deletions llama_stack/providers/tests/vector_io/test_vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,9 @@

import pytest

from llama_stack.providers.utils.memory.vector_store import (
content_from_doc,
MemoryBankDocument,
URL,
)
from llama_stack.apis.tools import RAGDocument

from llama_stack.providers.utils.memory.vector_store import content_from_doc, URL

DUMMY_PDF_PATH = Path(os.path.abspath(__file__)).parent / "fixtures" / "dummy.pdf"

Expand All @@ -41,33 +39,33 @@ class TestVectorStore:
@pytest.mark.asyncio
async def test_returns_content_from_pdf_data_uri(self):
data_uri = data_url_from_file(DUMMY_PDF_PATH)
doc = MemoryBankDocument(
doc = RAGDocument(
document_id="dummy",
content=data_uri,
mime_type="application/pdf",
metadata={},
)
content = await content_from_doc(doc)
assert content == "Dummy PDF file"
assert content == "Dumm y PDF file"

@pytest.mark.asyncio
async def test_downloads_pdf_and_returns_content(self):
# Using GitHub to host the PDF file
url = "https://raw.githubusercontent.com/meta-llama/llama-stack/da035d69cfca915318eaf485770a467ca3c2a238/llama_stack/providers/tests/memory/fixtures/dummy.pdf"
doc = MemoryBankDocument(
doc = RAGDocument(
document_id="dummy",
content=url,
mime_type="application/pdf",
metadata={},
)
content = await content_from_doc(doc)
assert content == "Dummy PDF file"
assert content == "Dumm y PDF file"

@pytest.mark.asyncio
async def test_downloads_pdf_and_returns_content_with_url_object(self):
# Using GitHub to host the PDF file
url = "https://raw.githubusercontent.com/meta-llama/llama-stack/da035d69cfca915318eaf485770a467ca3c2a238/llama_stack/providers/tests/memory/fixtures/dummy.pdf"
doc = MemoryBankDocument(
doc = RAGDocument(
document_id="dummy",
content=URL(
uri=url,
Expand All @@ -76,4 +74,4 @@ async def test_downloads_pdf_and_returns_content_with_url_object(self):
metadata={},
)
content = await content_from_doc(doc)
assert content == "Dummy PDF file"
assert content == "Dumm y PDF file"
4 changes: 2 additions & 2 deletions tests/client-sdk/agents/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def test_rag_agent(llama_stack_client, agent_config):
embedding_model="all-MiniLM-L6-v2",
embedding_dimension=384,
)
llama_stack_client.tool_runtime.rag_tool.insert_documents(
llama_stack_client.tool_runtime.rag_tool.insert(
documents=documents,
vector_db_id=vector_db_id,
chunk_size_in_tokens=512,
Expand Down Expand Up @@ -321,4 +321,4 @@ def test_rag_agent(llama_stack_client, agent_config):
)
logs = [str(log) for log in EventLogger().log(response) if log is not None]
logs_str = "".join(logs)
assert "Tool:rag_tool.query_context" in logs_str
assert "Tool:query_from_memory" in logs_str
Loading

0 comments on commit a63a43c

Please sign in to comment.