Skip to content

Commit

Permalink
fix case where memory bank is registered without provider_id
Browse files Browse the repository at this point in the history
  • Loading branch information
yanxi0830 committed Oct 17, 2024
1 parent 9fcf5d5 commit f0600a3
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 2 deletions.
15 changes: 15 additions & 0 deletions llama_stack/apis/memory_banks/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,21 @@ async def run_main(host: str, port: int, stream: bool):
response = await client.list_memory_banks()
cprint(f"list_memory_banks response={response}", "green")

# register memory bank for the first time
response = await client.register_memory_bank(
VectorMemoryBankDef(
identifier="test_bank2",
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
)
)
cprint(f"register_memory_bank response={response}", "blue")

# list again after registering
response = await client.list_memory_banks()
cprint(f"list_memory_banks response={response}", "green")


def main(host: str, port: int, stream: bool = True):
asyncio.run(run_main(host, port, stream))
Expand Down
10 changes: 8 additions & 2 deletions llama_stack/distribution/routers/routing_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,16 @@ def get_object_by_identifier(
async def register_object(self, obj: RoutableObjectWithProvider):
entries = self.registry.get(obj.identifier, [])
for entry in entries:
if entry.provider_id == obj.provider_id:
print(f"`{obj.identifier}` already registered with `{obj.provider_id}`")
if entry.provider_id == obj.provider_id or not obj.provider_id:
print(
f"`{obj.identifier}` already registered with `{entry.provider_id}`"
)
return

# if provider_id is not specified, we'll pick an arbitrary one from existing entries
if not obj.provider_id and len(self.impls_by_provider_id) > 0:
obj.provider_id = list(self.impls_by_provider_id.keys())[0]

if obj.provider_id not in self.impls_by_provider_id:
raise ValueError(f"Provider `{obj.provider_id}` not found")

Expand Down

0 comments on commit f0600a3

Please sign in to comment.