[memory refactor][3/n] Introduce RAGToolRuntime as a specialized sub-protocol (#832)

See #827 for the broader
design.

Third part:
- We need the `tool_runtime.rag_tool.query_context()` and
`tool_runtime.rag_tool.insert_documents()` methods to work smoothly with
complete type safety. To that end, we introduce a sub-resource path
`tool-runtime/rag-tool/` and change the resolver to support it.
- The PR updates the agents implementation to call these typed
APIs directly for memory accesses rather than going through the complex, untyped
"invoke_tool" API. The code looks much nicer and simpler, as expected.
- There are still a number of hacks in the server resolver implementation;
we will live with some and fix others.

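The typed sub-resource access described in the list above can be sketched roughly as follows. This is a minimal illustration under stated assumptions, not the code from this PR: only the names `RAGToolRuntime`, `ToolRuntime`, `rag_tool`, `query_context`, and `insert_documents` come from the description. The method signatures, the `InMemoryRAGTool` class, and the `Runtime` container are hypothetical. The point is that exposing the sub-protocol as a typed attribute lets callers write `tool_runtime.rag_tool.query_context(...)` with static type checking instead of passing untyped arguments to a generic `invoke_tool` call.

```python
import asyncio
from dataclasses import dataclass, field
from typing import Dict, List, Protocol


class RAGToolRuntime(Protocol):
    # Hypothetical signatures, chosen only for illustration.
    async def insert_documents(self, documents: List[str], vector_db_id: str) -> None: ...

    async def query_context(self, query: str, vector_db_ids: List[str]) -> List[str]: ...


class ToolRuntime(Protocol):
    # The sub-resource is a typed attribute of the parent protocol.
    rag_tool: RAGToolRuntime


@dataclass
class InMemoryRAGTool:
    """Toy stand-in: keyword matching instead of real retrieval."""

    store: Dict[str, List[str]] = field(default_factory=dict)

    async def insert_documents(self, documents: List[str], vector_db_id: str) -> None:
        self.store.setdefault(vector_db_id, []).extend(documents)

    async def query_context(self, query: str, vector_db_ids: List[str]) -> List[str]:
        words = query.lower().split()
        return [
            doc
            for db_id in vector_db_ids
            for doc in self.store.get(db_id, [])
            if any(word in doc.lower() for word in words)
        ]


@dataclass
class Runtime:
    # Hypothetical container that satisfies the ToolRuntime protocol.
    rag_tool: RAGToolRuntime


async def demo() -> List[str]:
    tool_runtime: ToolRuntime = Runtime(rag_tool=InMemoryRAGTool())
    await tool_runtime.rag_tool.insert_documents(
        ["Llamas are camelids.", "Paris is in France."], "db1"
    )
    return await tool_runtime.rag_tool.query_context("llamas", ["db1"])


result = asyncio.run(demo())
print(result)
```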
Note that we must make sure the client SDKs can also handle this
sub-resource complexity. Stainless supports sub-resources, so
this should be possible, but beware.

## Test Plan

Our RAG test is sad (it doesn't check the actual RAG output), but I
verified that the implementation works. I will work on fixing the RAG
test afterwards.

```bash
pytest -s -v tests/agents/test_agents.py -k "rag and together" --safety-shield=meta-llama/Llama-Guard-3-8B
```
ashwinb authored Jan 22, 2025
1 parent 78a481b commit 1a74904
Showing 33 changed files with 1,666 additions and 1,363 deletions.
6 changes: 6 additions & 0 deletions docs/openapi_generator/pyopenapi/operations.py
@@ -172,10 +172,16 @@ def _get_endpoint_functions(
def _get_defining_class(member_fn: str, derived_cls: type) -> type:
    "Find the class in which a member function is first defined in a class inheritance hierarchy."

    # This import must be dynamic here
    from llama_stack.apis.tools import RAGToolRuntime, ToolRuntime

    # iterate in reverse member resolution order to find most specific class first
    for cls in reversed(inspect.getmro(derived_cls)):
        for name, _ in inspect.getmembers(cls, inspect.isfunction):
            if name == member_fn:
                # HACK ALERT
                if cls == RAGToolRuntime:
                    return ToolRuntime
                return cls

    raise ValidationError(

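The lookup in the hunk above can be tried in isolation with the self-contained sketch below. It is an assumption-laden illustration, not the project's code: the classes here are stand-ins defined locally (they only borrow the names `ToolRuntime` and `RAGToolRuntime`), `Impl` and the `remap` parameter are hypothetical, and `LookupError` replaces the `ValidationError` whose arguments are cut off in the diff. Walking `inspect.getmro()` in reverse visits base classes before derived ones, so the first match is the class where the member is first defined. The remapping step mirrors the effect of the "HACK ALERT" branch, which reports members first defined on `RAGToolRuntime` as belonging to `ToolRuntime`.

```python
import inspect
from typing import Dict, Optional


class ToolRuntime:
    def list_tools(self) -> None: ...


class RAGToolRuntime:
    def query_context(self) -> None: ...


class Impl(ToolRuntime, RAGToolRuntime):
    # Hypothetical implementation class that overrides a sub-protocol method.
    def query_context(self) -> None: ...


def get_defining_class(
    member_fn: str, derived_cls: type, remap: Optional[Dict[type, type]] = None
) -> type:
    """Return the class that first defines member_fn, after optional remapping."""
    remap = remap or {}
    # Reversed MRO: object first, then base classes, then derived_cls last.
    for cls in reversed(inspect.getmro(derived_cls)):
        for name, _ in inspect.getmembers(cls, inspect.isfunction):
            if name == member_fn:
                # Report a substitute class when one is configured.
                return remap.get(cls, cls)
    raise LookupError(f"{member_fn} not found on {derived_cls.__name__}")


plain = get_defining_class("query_context", Impl)
hacked = get_defining_class("query_context", Impl, {RAGToolRuntime: ToolRuntime})
print(plain.__name__, hacked.__name__)
```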