Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[bugfix] fix case for agent when memory bank registered without specifying provider_id #264

Merged
merged 3 commits into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions llama_stack/apis/memory_banks/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,21 @@ async def run_main(host: str, port: int, stream: bool):
response = await client.list_memory_banks()
cprint(f"list_memory_banks response={response}", "green")

# register memory bank for the first time
response = await client.register_memory_bank(
VectorMemoryBankDef(
identifier="test_bank2",
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
)
)
cprint(f"register_memory_bank response={response}", "blue")

# list again after registering
response = await client.list_memory_banks()
cprint(f"list_memory_banks response={response}", "green")


def main(host: str, port: int, stream: bool = True):
asyncio.run(run_main(host, port, stream))
Expand Down
10 changes: 8 additions & 2 deletions llama_stack/distribution/routers/routing_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,16 @@ def get_object_by_identifier(
async def register_object(self, obj: RoutableObjectWithProvider):
entries = self.registry.get(obj.identifier, [])
for entry in entries:
if entry.provider_id == obj.provider_id:
print(f"`{obj.identifier}` already registered with `{obj.provider_id}`")
if entry.provider_id == obj.provider_id or not obj.provider_id:
print(
f"`{obj.identifier}` already registered with `{entry.provider_id}`"
)
return

# if provider_id is not specified, we'll pick an arbitrary one from existing entries
if not obj.provider_id and len(self.impls_by_provider_id) > 0:
obj.provider_id = list(self.impls_by_provider_id.keys())[0]

if obj.provider_id not in self.impls_by_provider_id:
raise ValueError(f"Provider `{obj.provider_id}` not found")

Expand Down